From 1fe45372b0a2ed46cab7cc30ece6c0f5301a5c42 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Thu, 21 Nov 2024 20:19:22 -0800 Subject: [PATCH] fix: redo #2085 (#2087) --- letta/client/client.py | 15 +++++++++++ letta/schemas/block.py | 7 +++++ letta/schemas/memory.py | 13 +++++++++ letta/server/rest_api/routers/v1/agents.py | 24 ++++++++++++++++- letta/server/server.py | 30 +++++++++++++++++++++ tests/test_client.py | 31 ++++++++++++++++++++++ tests/test_memory.py | 13 +++++++++ 7 files changed, 132 insertions(+), 1 deletion(-) diff --git a/letta/client/client.py b/letta/client/client.py index 6065b411..8a3d0538 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1600,6 +1600,18 @@ class RESTClient(AbstractClient): raise ValueError(f"Failed to remove agent memory block: {response.text}") return Memory(**response.json()) + def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory: + + # @router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit") + response = requests.patch( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/limit", + headers=self.headers, + json={"label": block_label, "limit": limit}, + ) + if response.status_code != 200: + raise ValueError(f"Failed to update agent memory limit: {response.text}") + return Memory(**response.json()) + class LocalClient(AbstractClient): """ @@ -2823,3 +2835,6 @@ class LocalClient(AbstractClient): def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory: return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label) + + def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory: + return self.server.update_agent_memory_limit(user_id=self.user_id, agent_id=agent_id, block_label=block_label, limit=limit) diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 4680871f..b3acc866 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -124,6 +124,13 @@ class BlockUpdate(BaseBlock): extra = "ignore" # Ignores extra fields +class BlockLimitUpdate(BaseModel): + """Update the limit of a block""" + + label: str = Field(..., description="Label of the block.") + limit: int = Field(..., description="New limit of the block.") + + class UpdatePersona(BlockUpdate): """Update a persona block""" diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 751bdb54..82aae738 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -187,6 +187,19 @@ class Memory(BaseModel, validate_assignment=True): # Then swap the block to the new label self.memory[new_label] = self.memory.pop(current_label) + def update_block_limit(self, label: str, limit: int): + """Update the limit of a block""" + if label not in self.memory: + raise ValueError(f"Block with label {label} does not exist") + if not isinstance(limit, int): + raise ValueError(f"Provided limit must be an integer") + + # Check to make sure the new limit is greater than the current length of the block + if len(self.memory[label].value) > limit: + raise ValueError(f"New limit {limit} is less than the current length of the block {len(self.memory[label].value)}") + + self.memory[label].limit = limit + # TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names. class BasicBlockMemory(Memory): diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index a602428c..bdc6a577 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState -from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate +from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate, BlockLimitUpdate from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( LegacyLettaMessage, @@ -218,6 +218,7 @@ def update_agent_memory( ): """ Update the core memory of a specific agent. + This endpoint accepts new memory contents (labels as keys, and values as values) and updates the core memory of the agent identified by the user ID and agent ID. This endpoint accepts new memory contents to update the core memory of the agent. This endpoint only supports modifying existing blocks; it does not support deleting/unlinking or creating/linking blocks. """ @@ -287,6 +288,27 @@ def remove_agent_memory_block( return updated_memory +@router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit") +def update_agent_memory_limit( + agent_id: str, + update_label: BlockLimitUpdate = Body(...), + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Update the limit of a block in an agent's memory. + """ + actor = server.get_user_or_default(user_id=user_id) + + memory = server.update_agent_memory_limit( + user_id=actor.id, + agent_id=agent_id, + block_label=update_label.label, + limit=update_label.limit, + ) + return memory + + @router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary") def get_agent_recall_memory_summary( agent_id: str, diff --git a/letta/server/server.py b/letta/server/server.py index 49abae20..2e6438a8 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1929,3 +1929,33 @@ class SyncServer(Server): raise ValueError(f"Agent with id {agent_id} not found after linking block") assert unlinked_block.label not in updated_agent.memory.list_block_labels() return updated_agent.memory + + def update_agent_memory_limit(self, user_id: str, agent_id: str, block_label: str, limit: int) -> Memory: + """Update the limit of a block in an agent's memory""" + + # Get the user + user = self.user_manager.get_user_by_id(user_id=user_id) + + # Link a block to an agent's memory + letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent.memory.update_block_limit(label=block_label, limit=limit) + assert block_label in letta_agent.memory.list_block_labels() + + # write out the update the database + self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(block_label), actor=user) + + # check that the block was updated + updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(block_label).id, actor=user) + assert updated_block and updated_block.limit == limit + + # Recompile the agent memory + letta_agent.rebuild_memory(force=True, ms=self.ms) + + # save agent + save_agent(letta_agent, self.ms) + + updated_agent = self.ms.get_agent(agent_id=agent_id) + if updated_agent is None: + raise ValueError(f"Agent with id {agent_id} not found after linking block") + assert updated_agent.memory.get_block(label=block_label).limit == limit + return updated_agent.memory diff --git a/tests/test_client.py b/tests/test_client.py index 191fa13e..5a84c5a3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -810,3 +810,34 @@ def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], a # finally: # client.delete_agent(new_agent.id) + + +def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent: AgentState): + """Test that we can update the limit of a block in an agent's memory""" + + agent = client.create_agent(name=create_random_username()) + + try: + current_labels = agent.memory.list_block_labels() + example_label = current_labels[0] + example_new_limit = 1 + current_block = agent.memory.get_block(label=example_label) + current_block_length = len(current_block.value) + + assert example_new_limit != agent.memory.get_block(label=example_label).limit + assert example_new_limit < current_block_length + + # We expect this to throw a value error + with pytest.raises(ValueError): + client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit) + + # Now try the same thing with a higher limit + example_new_limit = current_block_length + 10000 + assert example_new_limit > current_block_length + client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit) + + updated_agent = client.get_agent(agent_id=agent.id) + assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit + + finally: + client.delete_agent(agent.id) diff --git a/tests/test_memory.py b/tests/test_memory.py index e1c5e655..d8f2cc79 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -140,3 +140,16 @@ def test_update_block_label(sample_memory: Memory): sample_memory.update_block_label(current_label=test_old_label, new_label=test_new_label) assert test_new_label in sample_memory.list_block_labels() assert test_old_label not in sample_memory.list_block_labels() + + +def test_update_block_limit(sample_memory: Memory): + """Test updating the limit of a block""" + + test_new_limit = 1000 + current_labels = sample_memory.list_block_labels() + test_old_label = current_labels[0] + + assert sample_memory.get_block(label=test_old_label).limit != test_new_limit + + sample_memory.update_block_limit(label=test_old_label, limit=test_new_limit) + assert sample_memory.get_block(label=test_old_label).limit == test_new_limit