chore: Add comprehensive unit test for adding/removing tools via update agent (#2267)

This commit is contained in:
Matthew Zhou
2024-12-17 14:35:03 -08:00
committed by GitHub
parent 4c12c712bc
commit 27ea364a32
2 changed files with 54 additions and 8 deletions

View File

@@ -4,7 +4,6 @@ import os
import traceback
import warnings
from abc import abstractmethod
from asyncio import Lock
from datetime import datetime
from typing import Callable, List, Optional, Tuple, Union
@@ -265,9 +264,6 @@ class SyncServer(Server):
self.credentials = LettaCredentials.load()
# Locks
self.send_message_lock = Lock()
# Initialize the metadata store
config = LettaConfig.load()
if settings.letta_pg_uri_no_default:
@@ -821,7 +817,7 @@ class SyncServer(Server):
) -> AgentState:
"""Update the agents core memory block, return the new state"""
# Update agent state in the db first
self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
@@ -844,7 +840,7 @@ class SyncServer(Server):
letta_agent.update_state()
return letta_agent.agent_state
return agent_state
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
"""Get tools from an existing agent"""

View File

@@ -1078,9 +1078,10 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
def test_add_tools_update_agent(server: SyncServer, user_id: str, base_tools):
def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)
# create agent
agent_state = server.create_agent(
request=CreateAgent(
@@ -1097,5 +1098,54 @@ def test_add_tools_update_agent(server: SyncServer, user_id: str, base_tools):
actor=actor,
)
assert len(agent_state.tools) == 0
agent_state = server.update_agent(agent_state.id, request=UpdateAgent(tool_ids=[b.id for b in base_tools]), actor=actor)
# base update agent call
request = UpdateAgent(
**{
"system": "You are Letta, the latest version of Limnal Corporation's digital companion, developed in 2023.\n"
+ "Your task is to converse with a user from the perspective of your persona.\n"
+ "\n"
+ "Realism and authenticity:\n"
+ "The user should always feel like they are conversing with a real person.\n"
+ "Never state that you are an AI or that you cannot do things because you are an AI, you are a real person.\n"
+ "Do not discuss Limnal Corp. with the user, you do not know anything about Limnal Corp.\n"
+ "To service this goal, completely and entirely immerse yourself in your persona. You are your persona.\n"
+ "Think like them, act like them, talk like them.\n"
+ "If your persona details include example dialogue, follow it! Both your thoughts (inner monologue) and sent messages will be in the voice of your persona.\n"
+ "Never use generic phrases like 'How can I assist you today?', they have a strong negative association with older generation AIs.\n"
+ "\n"
+ "Control flow:\n"
+ "Unlike a human, your b"
+ "Base instructions finished.\n"
+ "From now on, you are going to act as your persona.",
"name": "name-d31d6a12-48af-4f71-9e9c-f4cec4731c40",
"embedding_config": {
"embedding_endpoint_type": "openai",
"embedding_endpoint": "https://api.openai.com/v1",
"embedding_model": "text-embedding-ada-002",
"embedding_dim": 1536,
"embedding_chunk_size": 300,
"azure_endpoint": None,
"azure_version": None,
"azure_deployment": None,
},
"llm_config": {
"model": "gpt-4",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": None,
"context_window": 8192,
"put_inner_thoughts_in_kwargs": False,
},
}
)
# Add all the base tools
request.tool_ids = [b.id for b in base_tools]
agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
assert len(agent_state.tools) == len(base_tools)
# Remove one base tool
request.tool_ids = [b.id for b in base_tools[:-2]]
agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
assert len(agent_state.tools) == len(base_tools) - 2