fix: Fix update agent (#2265)

This commit is contained in:
Matthew Zhou
2024-12-16 21:43:13 -08:00
committed by GitHub
parent 80f32b41f5
commit 4c12c712bc
2 changed files with 32 additions and 44 deletions

View File

@@ -21,7 +21,7 @@ from letta.schemas.user import User
utils.DEBUG = True
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.job import Job as PydanticJob
from letta.schemas.llm_config import LLMConfig
@@ -393,9 +393,7 @@ def test_load_data(server, user_id, agent_id):
user = server.user_manager.get_user_or_default(user_id=user_id)
# create source
passages_before = server.agent_manager.list_passages(
actor=user, agent_id=agent_id, cursor=None, limit=10000
)
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
assert len(passages_before) == 0
source = server.source_manager.create_source(
@@ -494,7 +492,7 @@ def test_get_archival_memory(server, user_id, agent_id):
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
latest = passages_1[0]
latest = passages_1[0]
earliest = passages_2[-1]
# test archival memory
@@ -966,7 +964,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert initial_passage_count == 0
# Create a source
source = server.source_manager.create_source(
PydanticSource(
@@ -1079,3 +1076,26 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert len(passages2) == len(passages)
assert any("chicken" in passage.text.lower() for passage in passages2)
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
def test_add_tools_update_agent(server: SyncServer, user_id: str, base_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)
# create agent
agent_state = server.create_agent(
request=CreateAgent(
name="memory_rebuild_test_agent",
tool_ids=[],
memory_blocks=[
CreateBlock(label="human", value="The human's name is Bob."),
CreateBlock(label="persona", value="My name is Alice."),
],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=actor,
)
assert len(agent_state.tools) == 0
agent_state = server.update_agent(agent_state.id, request=UpdateAgent(tool_ids=[b.id for b in base_tools]), actor=actor)
assert len(agent_state.tools) == len(base_tools)