fix: Fix update agent (#2265)
This commit is contained in:
@@ -820,13 +820,13 @@ class SyncServer(Server):
|
||||
actor: User,
|
||||
) -> AgentState:
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
# Update agent state in the db first
|
||||
self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Update tags
|
||||
if request.tags is not None: # Allow for empty list
|
||||
letta_agent.agent_state.tags = request.tags
|
||||
|
||||
# TODO: Everything below needs to get removed, no updating anything in memory
|
||||
# update the system prompt
|
||||
if request.system:
|
||||
letta_agent.update_system_prompt(request.system)
|
||||
@@ -840,42 +840,10 @@ class SyncServer(Server):
|
||||
|
||||
# tools
|
||||
if request.tool_ids:
|
||||
# Replace tools and also re-link
|
||||
letta_agent.link_tools(letta_agent.agent_state.tools)
|
||||
|
||||
# (1) get tools + make sure they exist
|
||||
# Current and target tools as sets of tool names
|
||||
current_tools = letta_agent.agent_state.tools
|
||||
current_tool_ids = set([t.id for t in current_tools])
|
||||
target_tool_ids = set(request.tool_ids)
|
||||
letta_agent.update_state()
|
||||
|
||||
# Calculate tools to add and remove
|
||||
tool_ids_to_add = target_tool_ids - current_tool_ids
|
||||
tools_ids_to_remove = current_tool_ids - target_tool_ids
|
||||
|
||||
# update agent tool list
|
||||
for tool_id in tools_ids_to_remove:
|
||||
self.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
||||
for tool_id in tool_ids_to_add:
|
||||
self.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
||||
|
||||
# reload agent
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# configs
|
||||
if request.llm_config:
|
||||
letta_agent.agent_state.llm_config = request.llm_config
|
||||
if request.embedding_config:
|
||||
letta_agent.agent_state.embedding_config = request.embedding_config
|
||||
|
||||
# other minor updates
|
||||
if request.name:
|
||||
letta_agent.agent_state.name = request.name
|
||||
if request.metadata_:
|
||||
letta_agent.agent_state.metadata_ = request.metadata_
|
||||
|
||||
# save the agent
|
||||
save_agent(letta_agent)
|
||||
# TODO: probably reload the agent somehow?
|
||||
return letta_agent.agent_state
|
||||
|
||||
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
|
||||
|
||||
@@ -21,7 +21,7 @@ from letta.schemas.user import User
|
||||
|
||||
utils.DEBUG = True
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -393,9 +393,7 @@ def test_load_data(server, user_id, agent_id):
|
||||
user = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# create source
|
||||
passages_before = server.agent_manager.list_passages(
|
||||
actor=user, agent_id=agent_id, cursor=None, limit=10000
|
||||
)
|
||||
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
|
||||
assert len(passages_before) == 0
|
||||
|
||||
source = server.source_manager.create_source(
|
||||
@@ -494,7 +492,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
|
||||
latest = passages_1[0]
|
||||
latest = passages_1[0]
|
||||
earliest = passages_2[-1]
|
||||
|
||||
# test archival memory
|
||||
@@ -966,7 +964,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
assert initial_passage_count == 0
|
||||
|
||||
|
||||
# Create a source
|
||||
source = server.source_manager.create_source(
|
||||
PydanticSource(
|
||||
@@ -1079,3 +1076,26 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert len(passages2) == len(passages)
|
||||
assert any("chicken" in passage.text.lower() for passage in passages2)
|
||||
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
|
||||
|
||||
|
||||
def test_add_tools_update_agent(server: SyncServer, user_id: str, base_tools):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="memory_rebuild_test_agent",
|
||||
tool_ids=[],
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="The human's name is Bob."),
|
||||
CreateBlock(label="persona", value="My name is Alice."),
|
||||
],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
assert len(agent_state.tools) == 0
|
||||
agent_state = server.update_agent(agent_state.id, request=UpdateAgent(tool_ids=[b.id for b in base_tools]), actor=actor)
|
||||
assert len(agent_state.tools) == len(base_tools)
|
||||
|
||||
Reference in New Issue
Block a user