diff --git a/letta/client/client.py b/letta/client/client.py index 0bd65017..a8fc6f32 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1565,6 +1565,18 @@ class RESTClient(AbstractClient): # Parse and return the deleted organization return Organization(**response.json()) + def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory: + + # @router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit") + response = requests.patch( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/limit", + headers=self.headers, + json={"label": block_label, "limit": limit}, + ) + if response.status_code != 200: + raise ValueError(f"Failed to update agent memory limit: {response.text}") + return Memory(**response.json()) + class LocalClient(AbstractClient): """ @@ -2773,3 +2785,6 @@ class LocalClient(AbstractClient): def delete_org(self, org_id: str) -> Organization: return self.server.organization_manager.delete_organization_by_id(org_id=org_id) + + def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory: + return self.server.update_agent_memory_limit(user_id=self.user_id, agent_id=agent_id, block_label=block_label, limit=limit) diff --git a/letta/schemas/block.py b/letta/schemas/block.py index eb516aba..b06962b8 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import Field, model_validator +from pydantic import BaseModel, Field, model_validator from typing_extensions import Self from letta.schemas.letta_base import LettaBase @@ -117,6 +117,13 @@ class BlockUpdate(BaseBlock): extra = "ignore" # Ignores extra fields +class BlockLimitUpdate(BaseModel): + """Update the limit of a block""" + + label: str = Field(..., description="Label of the block.") + limit: int = Field(..., description="New limit of the block.") + + class UpdatePersona(BlockUpdate): """Update a persona block""" diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 1ce7b4c7..9026f83e 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -167,6 +167,19 @@ class Memory(BaseModel, validate_assignment=True): self.memory[label].value = value + def update_block_limit(self, label: str, limit: int): + """Update the limit of a block""" + if label not in self.memory: + raise ValueError(f"Block with label {label} does not exist") + if not isinstance(limit, int): + raise ValueError(f"Provided limit must be an integer") + + # Check to make sure the new limit is greater than the current length of the block + if len(self.memory[label].value) > limit: + raise ValueError(f"New limit {limit} is less than the current length of the block {len(self.memory[label].value)}") + + self.memory[label].limit = limit + # TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names. class BasicBlockMemory(Memory): diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index d100f9bf..7b4185c3 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState +from letta.schemas.block import BlockLimitUpdate from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( LegacyLettaMessage, @@ -217,7 +218,7 @@ def update_agent_memory( ): """ Update the core memory of a specific agent. - This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID. + This endpoint accepts new memory contents (labels as keys, and values as values) and updates the core memory of the agent identified by the user ID and agent ID. """ actor = server.get_user_or_default(user_id=user_id) @@ -225,6 +226,27 @@ def update_agent_memory( return memory +@router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit") +def update_agent_memory_limit( + agent_id: str, + update_label: BlockLimitUpdate = Body(...), + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Update the limit of a block in an agent's memory. + """ + actor = server.get_user_or_default(user_id=user_id) + + memory = server.update_agent_memory_limit( + user_id=actor.id, + agent_id=agent_id, + block_label=update_label.label, + limit=update_label.limit, + ) + return memory + + @router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary") def get_agent_recall_memory_summary( agent_id: str, diff --git a/letta/server/server.py b/letta/server/server.py index 309fb3a3..2f9f1da1 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1827,3 +1827,33 @@ class SyncServer(Server): # Get the current message letta_agent = self._get_or_load_agent(agent_id=agent_id) return letta_agent.get_context_window() + + def update_agent_memory_limit(self, user_id: str, agent_id: str, block_label: str, limit: int) -> Memory: + """Update the limit of a block in an agent's memory""" + + # Get the user + user = self.user_manager.get_user_by_id(user_id=user_id) + + # Link a block to an agent's memory + letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent.memory.update_block_limit(label=block_label, limit=limit) + assert block_label in letta_agent.memory.list_block_labels() + + # write out the update the database + self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(block_label), actor=user) + + # check that the block was updated + updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(block_label).id, actor=user) + assert updated_block and updated_block.limit == limit + + # Recompile the agent memory + letta_agent.rebuild_memory(force=True, ms=self.ms) + + # save agent + save_agent(letta_agent, self.ms) + + updated_agent = self.ms.get_agent(agent_id=agent_id) + if updated_agent is None: + raise ValueError(f"Agent with id {agent_id} not found after linking block") + assert updated_agent.memory.get_block(label=block_label).limit == limit + return updated_agent.memory diff --git a/tests/test_client.py b/tests/test_client.py index 4508b06e..5986f33b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -32,7 +32,7 @@ from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics from letta.services.tool_manager import ToolManager from letta.settings import model_settings -from letta.utils import get_utc_time +from letta.utils import create_random_username, get_utc_time from tests.helpers.client_helper import upload_file_using_client # from tests.utils import create_config @@ -730,3 +730,34 @@ def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], a # Verify all tags are removed final_tags = client.get_agent(agent_id=agent.id).tags assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}" + + +def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent: AgentState): + """Test that we can update the limit of a block in an agent's memory""" + + agent = client.create_agent(name=create_random_username()) + + try: + current_labels = agent.memory.list_block_labels() + example_label = current_labels[0] + example_new_limit = 1 + current_block = agent.memory.get_block(label=example_label) + current_block_length = len(current_block.value) + + assert example_new_limit != agent.memory.get_block(label=example_label).limit + assert example_new_limit < current_block_length + + # We expect this to throw a value error + with pytest.raises(ValueError): + client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit) + + # Now try the same thing with a higher limit + example_new_limit = current_block_length + 10000 + assert example_new_limit > current_block_length + client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit) + + updated_agent = client.get_agent(agent_id=agent.id) + assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit + + finally: + client.delete_agent(agent.id) diff --git a/tests/test_memory.py b/tests/test_memory.py index 3760f31a..a4443cab 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -105,3 +105,16 @@ def test_memory_jinja2_set_template(sample_memory: Memory): ) with pytest.raises(ValueError): sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure) + + +def test_update_block_limit(sample_memory: Memory): + """Test updating the limit of a block""" + + test_new_limit = 1000 + current_labels = sample_memory.list_block_labels() + test_old_label = current_labels[0] + + assert sample_memory.get_block(label=test_old_label).limit != test_new_limit + + sample_memory.update_block_limit(label=test_old_label, limit=test_new_limit) + assert sample_memory.get_block(label=test_old_label).limit == test_new_limit