fix: Fix update agent (#2265)

This commit is contained in:
Matthew Zhou
2024-12-16 21:43:13 -08:00
committed by GitHub
parent 80f32b41f5
commit 4c12c712bc
2 changed files with 32 additions and 44 deletions

View File

@@ -820,13 +820,13 @@ class SyncServer(Server):
actor: User,
) -> AgentState:
"""Update the agents core memory block, return the new state"""
# Update agent state in the db first
self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
# Update tags
if request.tags is not None: # Allow for empty list
letta_agent.agent_state.tags = request.tags
# TODO: Everything below needs to get removed, no updating anything in memory
# update the system prompt
if request.system:
letta_agent.update_system_prompt(request.system)
@@ -840,42 +840,10 @@ class SyncServer(Server):
# tools
if request.tool_ids:
# Replace tools and also re-link
letta_agent.link_tools(letta_agent.agent_state.tools)
# (1) get tools + make sure they exist
# Current and target tools as sets of tool names
current_tools = letta_agent.agent_state.tools
current_tool_ids = set([t.id for t in current_tools])
target_tool_ids = set(request.tool_ids)
letta_agent.update_state()
# Calculate tools to add and remove
tool_ids_to_add = target_tool_ids - current_tool_ids
tools_ids_to_remove = current_tool_ids - target_tool_ids
# update agent tool list
for tool_id in tools_ids_to_remove:
self.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
for tool_id in tool_ids_to_add:
self.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
# reload agent
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
# configs
if request.llm_config:
letta_agent.agent_state.llm_config = request.llm_config
if request.embedding_config:
letta_agent.agent_state.embedding_config = request.embedding_config
# other minor updates
if request.name:
letta_agent.agent_state.name = request.name
if request.metadata_:
letta_agent.agent_state.metadata_ = request.metadata_
# save the agent
save_agent(letta_agent)
# TODO: probably reload the agent somehow?
return letta_agent.agent_state
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:

View File

@@ -21,7 +21,7 @@ from letta.schemas.user import User
utils.DEBUG = True
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.job import Job as PydanticJob
from letta.schemas.llm_config import LLMConfig
@@ -393,9 +393,7 @@ def test_load_data(server, user_id, agent_id):
user = server.user_manager.get_user_or_default(user_id=user_id)
# create source
passages_before = server.agent_manager.list_passages(
actor=user, agent_id=agent_id, cursor=None, limit=10000
)
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
assert len(passages_before) == 0
source = server.source_manager.create_source(
@@ -494,7 +492,7 @@ def test_get_archival_memory(server, user_id, agent_id):
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
latest = passages_1[0]
latest = passages_1[0]
earliest = passages_2[-1]
# test archival memory
@@ -966,7 +964,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert initial_passage_count == 0
# Create a source
source = server.source_manager.create_source(
PydanticSource(
@@ -1079,3 +1076,26 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert len(passages2) == len(passages)
assert any("chicken" in passage.text.lower() for passage in passages2)
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
def test_add_tools_update_agent(server: SyncServer, user_id: str, base_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)
# create agent
agent_state = server.create_agent(
request=CreateAgent(
name="memory_rebuild_test_agent",
tool_ids=[],
memory_blocks=[
CreateBlock(label="human", value="The human's name is Bob."),
CreateBlock(label="persona", value="My name is Alice."),
],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=actor,
)
assert len(agent_state.tools) == 0
agent_state = server.update_agent(agent_state.id, request=UpdateAgent(tool_ids=[b.id for b in base_tools]), actor=actor)
assert len(agent_state.tools) == len(base_tools)