diff --git a/letta/server/server.py b/letta/server/server.py index 4a48b2a2..8c3a4312 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -820,13 +820,13 @@ class SyncServer(Server): actor: User, ) -> AgentState: """Update the agents core memory block, return the new state""" + # Update agent state in the db first + self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor) + # Get the agent object (loaded in memory) letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - # Update tags - if request.tags is not None: # Allow for empty list - letta_agent.agent_state.tags = request.tags - + # TODO: Everything below needs to get removed, no updating anything in memory # update the system prompt if request.system: letta_agent.update_system_prompt(request.system) @@ -840,42 +840,10 @@ class SyncServer(Server): # tools if request.tool_ids: - # Replace tools and also re-link + letta_agent.link_tools(letta_agent.agent_state.tools) - # (1) get tools + make sure they exist - # Current and target tools as sets of tool names - current_tools = letta_agent.agent_state.tools - current_tool_ids = set([t.id for t in current_tools]) - target_tool_ids = set(request.tool_ids) + letta_agent.update_state() - # Calculate tools to add and remove - tool_ids_to_add = target_tool_ids - current_tool_ids - tools_ids_to_remove = current_tool_ids - target_tool_ids - - # update agent tool list - for tool_id in tools_ids_to_remove: - self.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) - for tool_id in tool_ids_to_add: - self.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) - - # reload agent - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - - # configs - if request.llm_config: - letta_agent.agent_state.llm_config = request.llm_config - if request.embedding_config: - letta_agent.agent_state.embedding_config = request.embedding_config - - # other minor updates - if request.name: - letta_agent.agent_state.name = request.name - if request.metadata_: - letta_agent.agent_state.metadata_ = request.metadata_ - - # save the agent - save_agent(letta_agent) - # TODO: probably reload the agent somehow? return letta_agent.agent_state def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]: diff --git a/tests/test_server.py b/tests/test_server.py index 0718e6ce..56c132ec 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -21,7 +21,7 @@ from letta.schemas.user import User utils.DEBUG = True from letta.config import LettaConfig -from letta.schemas.agent import CreateAgent +from letta.schemas.agent import CreateAgent, UpdateAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.job import Job as PydanticJob from letta.schemas.llm_config import LLMConfig @@ -393,9 +393,7 @@ def test_load_data(server, user_id, agent_id): user = server.user_manager.get_user_or_default(user_id=user_id) # create source - passages_before = server.agent_manager.list_passages( - actor=user, agent_id=agent_id, cursor=None, limit=10000 - ) + passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000) assert len(passages_before) == 0 source = server.source_manager.create_source( @@ -494,7 +492,7 @@ def test_get_archival_memory(server, user_id, agent_id): assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test - latest = passages_1[0] + latest = passages_1[0] earliest = passages_2[-1] # test archival memory @@ -966,7 +964,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor) assert initial_passage_count == 0 - # Create a source source = server.source_manager.create_source( PydanticSource( @@ -1079,3 +1076,26 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot assert len(passages2) == len(passages) assert any("chicken" in passage.text.lower() for passage in passages2) assert any("Anna".lower() in passage.text.lower() for passage in passages2) + + +def test_add_tools_update_agent(server: SyncServer, user_id: str, base_tools): + """Test that the memory rebuild is generating the correct number of role=system messages""" + actor = server.user_manager.get_user_or_default(user_id) + # create agent + agent_state = server.create_agent( + request=CreateAgent( + name="memory_rebuild_test_agent", + tool_ids=[], + memory_blocks=[ + CreateBlock(label="human", value="The human's name is Bob."), + CreateBlock(label="persona", value="My name is Alice."), + ], + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, + ), + actor=actor, + ) + assert len(agent_state.tools) == 0 + agent_state = server.update_agent(agent_state.id, request=UpdateAgent(tool_ids=[b.id for b in base_tools]), actor=actor) + assert len(agent_state.tools) == len(base_tools)