diff --git a/letta/client/client.py b/letta/client/client.py index 0bd65017..6065b411 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1565,6 +1565,41 @@ class RESTClient(AbstractClient): # Parse and return the deleted organization return Organization(**response.json()) + def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: + + # @router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label") + response = requests.patch( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/label", + headers=self.headers, + json={"current_label": current_label, "new_label": new_label}, + ) + if response.status_code != 200: + raise ValueError(f"Failed to update agent memory label: {response.text}") + return Memory(**response.json()) + + def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Memory: + + # @router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block") + response = requests.post( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block", + headers=self.headers, + json=create_block.model_dump(), + ) + if response.status_code != 200: + raise ValueError(f"Failed to add agent memory block: {response.text}") + return Memory(**response.json()) + + def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory: + + # @router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block") + response = requests.delete( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block/{block_label}", + headers=self.headers, + ) + if response.status_code != 200: + raise ValueError(f"Failed to remove agent memory block: {response.text}") + return Memory(**response.json()) + class LocalClient(AbstractClient): """ @@ -2773,3 +2808,18 @@ class LocalClient(AbstractClient): def delete_org(self, org_id: str) -> Organization: return self.server.organization_manager.delete_organization_by_id(org_id=org_id) + + def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: + return self.server.update_agent_memory_label( + user_id=self.user_id, agent_id=agent_id, current_block_label=current_label, new_block_label=new_label + ) + + def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Memory: + block_req = Block(**create_block.model_dump()) + block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req) + # Link the block to the agent + updated_memory = self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block.id) + return updated_memory + + def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory: + return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label) diff --git a/letta/schemas/block.py b/letta/schemas/block.py index eb516aba..4680871f 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import Field, model_validator +from pydantic import BaseModel, Field, model_validator from typing_extensions import Self from letta.schemas.letta_base import LettaBase @@ -95,6 +95,13 @@ class BlockCreate(BaseBlock): label: str = Field(..., description="Label of the block.") +class BlockLabelUpdate(BaseModel): + """Update the label of a block""" + + current_label: str = Field(..., description="Current label of the block.") + new_label: str = Field(..., description="New label of the block.") + + class CreatePersona(BlockCreate): """Create a persona block""" diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 1ce7b4c7..751bdb54 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -158,6 +158,13 @@ class Memory(BaseModel, validate_assignment=True): self.memory[block.label] = block + def unlink_block(self, block_label: str) -> Block: + """Unlink a block from the memory object""" + if block_label not in self.memory: + raise ValueError(f"Block with label {block_label} does not exist") + + return self.memory.pop(block_label) + def update_block_value(self, label: str, value: str): """Update the value of a block""" if label not in self.memory: @@ -167,6 +174,19 @@ class Memory(BaseModel, validate_assignment=True): self.memory[label].value = value + def update_block_label(self, current_label: str, new_label: str): + """Update the label of a block""" + if current_label not in self.memory: + raise ValueError(f"Block with label {current_label} does not exist") + if not isinstance(new_label, str): + raise ValueError(f"Provided new label must be a string") + + # First change the label of the block + self.memory[current_label].label = new_label + + # Then swap the block to the new label + self.memory[new_label] = self.memory.pop(current_label) + # TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names. class BasicBlockMemory(Memory): diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index d100f9bf..a602428c 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState +from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( LegacyLettaMessage, @@ -217,7 +218,8 @@ def update_agent_memory( ): """ Update the core memory of a specific agent. - This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID. + This endpoint accepts new memory contents to update the core memory of the agent. + This endpoint only supports modifying existing blocks; it does not support deleting/unlinking or creating/linking blocks. """ actor = server.get_user_or_default(user_id=user_id) @@ -225,6 +227,66 @@ def update_agent_memory( return memory +@router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label") +def update_agent_memory_label( + agent_id: str, + update_label: BlockLabelUpdate = Body(...), + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Update the label of a block in an agent's memory. + """ + actor = server.get_user_or_default(user_id=user_id) + + memory = server.update_agent_memory_label( + user_id=actor.id, agent_id=agent_id, current_block_label=update_label.current_label, new_block_label=update_label.new_label + ) + return memory + + +@router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block") +def add_agent_memory_block( + agent_id: str, + create_block: BlockCreate = Body(...), + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Creates a memory block and links it to the agent. + """ + actor = server.get_user_or_default(user_id=user_id) + + # Copied from POST /blocks + block_req = Block(**create_block.model_dump()) + block = server.block_manager.create_or_update_block(actor=actor, block=block_req) + + # Link the block to the agent + updated_memory = server.link_block_to_agent_memory(user_id=actor.id, agent_id=agent_id, block_id=block.id) + + return updated_memory + + +@router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block") +def remove_agent_memory_block( + agent_id: str, + # TODO should this be block_id, or the label? + # I think label is OK since it's user-friendly + guaranteed to be unique within a Memory object + block_label: str, + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted. + """ + actor = server.get_user_or_default(user_id=user_id) + + # Unlink the block from the agent + updated_memory = server.unlink_block_from_agent_memory(user_id=actor.id, agent_id=agent_id, block_label=block_label) + + return updated_memory + + @router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary") def get_agent_recall_memory_summary( agent_id: str, diff --git a/letta/server/server.py b/letta/server/server.py index 309fb3a3..49abae20 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -409,8 +409,10 @@ class SyncServer(Server): logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}") raise - def _get_or_load_agent(self, agent_id: str) -> Agent: + def _get_or_load_agent(self, agent_id: str, caching: bool = True) -> Agent: """Check if the agent is in-memory, then load""" + + # Gets the agent state agent_state = self.ms.get_agent(agent_id=agent_id) if not agent_state: raise ValueError(f"Agent does not exist") @@ -418,11 +420,24 @@ class SyncServer(Server): actor = self.user_manager.get_user_by_id(user_id) logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}") - # TODO: consider disabling loading cached agents due to potential concurrency issues - letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id) - if not letta_agent: - logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") + if caching: + # TODO: consider disabling loading cached agents due to potential concurrency issues + letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id) + if not letta_agent: + logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") + letta_agent = self._load_agent(agent_id=agent_id, actor=actor) + else: + # This breaks unit tests in test_local_client.py letta_agent = self._load_agent(agent_id=agent_id, actor=actor) + + # letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id) + # if not letta_agent: + # logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") + + # NOTE: no longer caching, always forcing a lot from the database + # Loads the agent objects + # letta_agent = self._load_agent(agent_id=agent_id, actor=actor) + return letta_agent def _step( @@ -1441,6 +1456,7 @@ class SyncServer(Server): # If we modified the memory contents, we need to rebuild the memory block inside the system message if modified: letta_agent.rebuild_memory() + # letta_agent.rebuild_memory(force=True, ms=self.ms) # This breaks unit tests in test_local_client.py # save agent save_agent(letta_agent, self.ms) @@ -1827,3 +1843,89 @@ class SyncServer(Server): # Get the current message letta_agent = self._get_or_load_agent(agent_id=agent_id) return letta_agent.get_context_window() + + def update_agent_memory_label(self, user_id: str, agent_id: str, current_block_label: str, new_block_label: str) -> Memory: + """Update the label of a block in an agent's memory""" + + # Get the user + user = self.user_manager.get_user_by_id(user_id=user_id) + + # Link a block to an agent's memory + letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent.memory.update_block_label(current_label=current_block_label, new_label=new_block_label) + assert new_block_label in letta_agent.memory.list_block_labels() + self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(new_block_label), actor=user) + + # check that the block was updated + updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(new_block_label).id, actor=user) + + # Recompile the agent memory + letta_agent.rebuild_memory(force=True, ms=self.ms) + + # save agent + save_agent(letta_agent, self.ms) + + updated_agent = self.ms.get_agent(agent_id=agent_id) + if updated_agent is None: + raise ValueError(f"Agent with id {agent_id} not found after linking block") + assert new_block_label in updated_agent.memory.list_block_labels() + assert current_block_label not in updated_agent.memory.list_block_labels() + return updated_agent.memory + + def link_block_to_agent_memory(self, user_id: str, agent_id: str, block_id: str) -> Memory: + """Link a block to an agent's memory""" + + # Get the user + user = self.user_manager.get_user_by_id(user_id=user_id) + + # Get the block first + block = self.block_manager.get_block_by_id(block_id=block_id, actor=user) + if block is None: + raise ValueError(f"Block with id {block_id} not found") + + # Link a block to an agent's memory + letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent.memory.link_block(block=block) + assert block.label in letta_agent.memory.list_block_labels() + + # Recompile the agent memory + letta_agent.rebuild_memory(force=True, ms=self.ms) + + # save agent + save_agent(letta_agent, self.ms) + + updated_agent = self.ms.get_agent(agent_id=agent_id) + if updated_agent is None: + raise ValueError(f"Agent with id {agent_id} not found after linking block") + assert block.label in updated_agent.memory.list_block_labels() + + return updated_agent.memory + + def unlink_block_from_agent_memory(self, user_id: str, agent_id: str, block_label: str, delete_if_no_ref: bool = True) -> Memory: + """Unlink a block from an agent's memory. If the block is not linked to any agent, delete it.""" + + # Get the user + user = self.user_manager.get_user_by_id(user_id=user_id) + + # Link a block to an agent's memory + letta_agent = self._get_or_load_agent(agent_id=agent_id) + unlinked_block = letta_agent.memory.unlink_block(block_label=block_label) + assert unlinked_block.label not in letta_agent.memory.list_block_labels() + + # Check if the block is linked to any other agent + # TODO needs reference counting GC to handle loose blocks + # block = self.block_manager.get_block_by_id(block_id=unlinked_block.id, actor=user) + # if block is None: + # raise ValueError(f"Block with id {block_id} not found") + + # Recompile the agent memory + letta_agent.rebuild_memory(force=True, ms=self.ms) + + # save agent + save_agent(letta_agent, self.ms) + + updated_agent = self.ms.get_agent(agent_id=agent_id) + if updated_agent is None: + raise ValueError(f"Agent with id {agent_id} not found after linking block") + assert unlinked_block.label not in updated_agent.memory.list_block_labels() + return updated_agent.memory diff --git a/tests/test_client.py b/tests/test_client.py index 4508b06e..191fa13e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -15,6 +15,7 @@ from letta.client.client import LocalClient, RESTClient from letta.constants import DEFAULT_PRESET from letta.orm import FileMetadata, Source from letta.schemas.agent import AgentState +from letta.schemas.block import BlockCreate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, MessageStreamStatus from letta.schemas.letta_message import ( @@ -32,7 +33,7 @@ from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics from letta.services.tool_manager import ToolManager from letta.settings import model_settings -from letta.utils import get_utc_time +from letta.utils import create_random_username, get_utc_time from tests.helpers.client_helper import upload_file_using_client # from tests.utils import create_config @@ -730,3 +731,82 @@ def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], a # Verify all tags are removed final_tags = client.get_agent(agent_id=agent.id).tags assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}" + + +def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState): + """Test that we can update the label of a block in an agent's memory""" + + agent = client.create_agent(name=create_random_username()) + + try: + current_labels = agent.memory.list_block_labels() + example_label = current_labels[0] + example_new_label = "example_new_label" + assert example_new_label not in current_labels + + client.update_agent_memory_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label) + + updated_agent = client.get_agent(agent_id=agent.id) + assert example_new_label in updated_agent.memory.list_block_labels() + + finally: + client.delete_agent(agent.id) + + +def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState): + """Test that we can add and remove a block from an agent's memory""" + + agent = client.create_agent(name=create_random_username()) + + try: + current_labels = agent.memory.list_block_labels() + example_new_label = "example_new_label" + example_new_value = "example value" + assert example_new_label not in current_labels + + # Link a new memory block + client.add_agent_memory_block( + agent_id=agent.id, + create_block=BlockCreate( + label=example_new_label, + value=example_new_value, + limit=1000, + ), + ) + + updated_agent = client.get_agent(agent_id=agent.id) + assert example_new_label in updated_agent.memory.list_block_labels() + + # Now unlink the block + client.remove_agent_memory_block(agent_id=agent.id, block_label=example_new_label) + + updated_agent = client.get_agent(agent_id=agent.id) + assert example_new_label not in updated_agent.memory.list_block_labels() + + finally: + client.delete_agent(agent.id) + + +# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState): +# """Test that the token limit is enforced for the core memory blocks""" + +# # Create an agent +# new_agent = client.create_agent( +# name="test-core-memory-token-limits", +# tools=BASE_TOOLS, +# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=2000), +# ) + +# try: +# # Then intentionally set the limit to be extremely low +# client.update_agent( +# agent_id=new_agent.id, +# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=100), +# ) + +# # TODO we should probably not allow updating the core memory limit if + +# # TODO in which case we should modify this test to actually to a proper token counter check + +# finally: +# client.delete_agent(new_agent.id) diff --git a/tests/test_memory.py b/tests/test_memory.py index 3760f31a..e1c5e655 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -1,6 +1,7 @@ import pytest # Import the classes here, assuming the above definitions are in a module named memory_module +from letta.schemas.block import Block from letta.schemas.memory import ChatMemory, Memory @@ -105,3 +106,37 @@ def test_memory_jinja2_set_template(sample_memory: Memory): ) with pytest.raises(ValueError): sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure) + + +def test_link_unlink_block(sample_memory: Memory): + """Test linking and unlinking a block to the memory""" + + # Link a new block + + test_new_label = "test_new_label" + test_new_value = "test_new_value" + test_new_block = Block(label=test_new_label, value=test_new_value, limit=2000) + + current_labels = sample_memory.list_block_labels() + assert test_new_label not in current_labels + + sample_memory.link_block(block=test_new_block) + assert test_new_label in sample_memory.list_block_labels() + assert sample_memory.get_block(test_new_label).value == test_new_value + + # Unlink the block + sample_memory.unlink_block(block_label=test_new_label) + assert test_new_label not in sample_memory.list_block_labels() + + +def test_update_block_label(sample_memory: Memory): + """Test updating the label of a block""" + + test_new_label = "test_new_label" + current_labels = sample_memory.list_block_labels() + assert test_new_label not in current_labels + test_old_label = current_labels[0] + + sample_memory.update_block_label(current_label=test_old_label, new_label=test_new_label) + assert test_new_label in sample_memory.list_block_labels() + assert test_old_label not in sample_memory.list_block_labels()