chore: Clean up .load_agent usage (#2298)

This commit is contained in:
Matthew Zhou
2024-12-20 16:56:53 -08:00
committed by GitHub
parent a5b1aac1fd
commit 9ad5fd64cf
10 changed files with 134 additions and 164 deletions

View File

@@ -0,0 +1,35 @@
"""Add cascading deletes for sources to agents
Revision ID: e78b4e82db30
Revises: d6632deac81d
Create Date: 2024-12-20 16:30:17.095888
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "e78b4e82db30"
down_revision: Union[str, None] = "d6632deac81d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Replace the sources_agents foreign keys with ON DELETE CASCADE versions."""
    # Drop the original (non-cascading) foreign keys.
    op.drop_constraint("sources_agents_agent_id_fkey", "sources_agents", type_="foreignkey")
    op.drop_constraint("sources_agents_source_id_fkey", "sources_agents", type_="foreignkey")
    # Recreate them with ON DELETE CASCADE. Name the constraints explicitly
    # (matching the Postgres default <table>_<column>_fkey) instead of passing
    # None, so that downgrade() can drop them by name and the migration stays
    # reversible.
    op.create_foreign_key("sources_agents_source_id_fkey", "sources_agents", "sources", ["source_id"], ["id"], ondelete="CASCADE")
    op.create_foreign_key("sources_agents_agent_id_fkey", "sources_agents", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
def downgrade() -> None:
    """Restore the original non-cascading foreign keys on sources_agents."""
    # The upgrade created its FKs without explicit names, so Postgres assigned
    # the default names <table>_<column>_fkey. Alembic's drop_constraint raises
    # a ValueError when given None as the constraint name, so drop the
    # constraints by those generated names instead.
    op.drop_constraint("sources_agents_source_id_fkey", "sources_agents", type_="foreignkey")
    op.drop_constraint("sources_agents_agent_id_fkey", "sources_agents", type_="foreignkey")
    # Recreate the original foreign keys (no ON DELETE CASCADE).
    op.create_foreign_key("sources_agents_source_id_fkey", "sources_agents", "sources", ["source_id"], ["id"])
    op.create_foreign_key("sources_agents_agent_id_fkey", "sources_agents", "agents", ["agent_id"], ["id"])

View File

@@ -44,7 +44,6 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
@@ -53,7 +52,6 @@ from letta.services.helpers.agent_manager_helper import (
)
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import (
@@ -969,32 +967,6 @@ class Agent(BaseAgent):
# TODO: recall memory
raise NotImplementedError()
def attach_source(
    self,
    user: PydanticUser,
    source_id: str,
    source_manager: SourceManager,
    agent_manager: AgentManager,
):
    """Link a data source to this agent through the SourcesAgents relationship.

    Args:
        user: User performing the action.
        source_id: ID of the source to attach.
        source_manager: Used to confirm the source exists and is visible to the user.
        agent_manager: Persists the agent<->source link.
    """
    # Confirm the source exists and belongs to the user's organization before linking.
    source = source_manager.get_source_by_id(source_id=source_id, actor=user)
    assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})"

    # Persist the relationship through the manager layer rather than touching ORM rows directly.
    agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)

    printd(
        f"Attached data source {source.name} to agent {self.agent_state.name}.",
    )
def get_context_window(self) -> ContextWindowOverview:
"""Get the context window of the agent"""

View File

@@ -2987,7 +2987,11 @@ class LocalClient(AbstractClient):
source_id (str): ID of the source
source_name (str): Name of the source
"""
self.server.attach_source_to_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
if source_name:
source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
source_id = source.id
self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user)
def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
"""
@@ -2999,7 +3003,10 @@ class LocalClient(AbstractClient):
Returns:
source (Source): Detached source
"""
return self.server.detach_source_from_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
if source_name:
source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
source_id = source.id
return self.server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=self.user)
def list_sources(self) -> List[Source]:
"""

View File

@@ -11,10 +11,10 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.source import Source as PydanticSource
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.file import FileMetadata
from letta.orm.passage import SourcePassage
from letta.orm.agent import Agent
from letta.orm.file import FileMetadata
from letta.orm.organization import Organization
from letta.orm.passage import SourcePassage
class Source(SqlalchemyBase, OrganizationMixin):
@@ -32,4 +32,11 @@ class Source(SqlalchemyBase, OrganizationMixin):
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
agents: Mapped[List["Agent"]] = relationship(
"Agent",
secondary="sources_agents",
back_populates="sources",
lazy="selectin",
cascade="all, delete", # Ensures rows in sources_agents are deleted when the source is deleted
passive_deletes=True, # Allows the database to handle deletion of orphaned rows
)

View File

@@ -9,5 +9,5 @@ class SourcesAgents(Base):
__tablename__ = "sources_agents"
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id"), primary_key=True)
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), primary_key=True)

View File

@@ -130,11 +130,8 @@ def attach_source_to_agent(
Attach a data source to an existing agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
assert source is not None, f"Source with id={source_id} not found."
source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id)
return source
server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=actor)
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
@router.post("/{source_id}/detach", response_model=Source, operation_id="detach_agent_from_source")
@@ -148,8 +145,8 @@ def detach_source_from_agent(
Detach a data source from an existing agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")

View File

@@ -59,7 +59,7 @@ from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUp
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.schemas.tool import Tool, ToolCreate
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
@@ -303,11 +303,6 @@ class SyncServer(Server):
self.block_manager.add_default_blocks(actor=self.default_user)
self.tool_manager.upsert_base_tools(actor=self.default_user)
# If there is a default org/user
# This logic may have to change in the future
if settings.load_default_external_tools:
self.add_default_external_tools(actor=self.default_user)
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()]
if model_settings.openai_api_key:
@@ -431,9 +426,6 @@ class SyncServer(Server):
skip_verify=True,
)
# save agent after step
# save_agent(letta_agent)
except Exception as e:
logger.error(f"Error in server._step: {e}")
print(traceback.print_exc())
@@ -944,11 +936,10 @@ class SyncServer(Server):
agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent_state in agent_states:
agent_id = agent_state.id
agent = self.load_agent(agent_id=agent_id, actor=actor)
# Attach source to agent
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor)
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
assert new_passage_size >= curr_passage_size # in case empty files are added
@@ -973,56 +964,6 @@ class SyncServer(Server):
passage_count, document_count = load_data(connector, source, self.passage_manager, self.source_manager, actor=user)
return passage_count, document_count
def attach_source_to_agent(
    self,
    user_id: str,
    agent_id: str,
    source_id: Optional[str] = None,
    source_name: Optional[str] = None,
) -> Source:
    """Attach a data source to an agent, resolving the source by id or name.

    Args:
        user_id: ID of the user performing the action.
        agent_id: ID of the agent to attach the source to.
        source_id: ID of the source (takes precedence over source_name).
        source_name: Name of the source, used when source_id is not given.

    Returns:
        The resolved Source that was attached.

    Raises:
        ValueError: If neither source_id nor source_name is provided, or if
            no matching source exists.
    """
    # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
    actor = self.user_manager.get_user_or_default(user_id=user_id)

    # Resolve the source, preferring the explicit id over the name.
    if source_id:
        data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
    elif source_name:
        data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
    else:
        raise ValueError("Need to provide at least source_id or source_name to find the source.")

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if data_source is None:
        raise ValueError(f"Data source with id={source_id} or name={source_name} does not exist")

    # Load the agent and attach the source through it.
    agent = self.load_agent(agent_id=agent_id, actor=actor)
    agent.attach_source(user=actor, source_id=data_source.id, source_manager=self.source_manager, agent_manager=self.agent_manager)

    return data_source
def detach_source_from_agent(
    self,
    user_id: str,
    agent_id: str,
    source_id: Optional[str] = None,
    source_name: Optional[str] = None,
) -> Source:
    """Detach a data source from an agent, resolving the source by id or name.

    Args:
        user_id: ID of the user performing the action.
        agent_id: ID of the agent to detach the source from.
        source_id: ID of the source (takes precedence over source_name).
        source_name: Name of the source, used when source_id is not given.

    Returns:
        The resolved Source that was detached.

    Raises:
        ValueError: If neither source_id nor source_name is provided, or if
            no matching source exists.
    """
    # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
    actor = self.user_manager.get_user_or_default(user_id=user_id)

    if source_id:
        source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
    elif source_name:
        source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
    else:
        raise ValueError("Need to provide at least source_id or source_name to find the source.")

    # Fail loudly on a missing source instead of crashing on `source.id` below
    # (name lookup) or silently returning None (id lookup).
    if source is None:
        raise ValueError(f"Data source with id={source_id} or name={source_name} does not exist")
    source_id = source.id

    # Delete the agent-source mapping.
    self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)

    # Return back source data.
    return source
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
return []
@@ -1060,22 +1001,6 @@ class SyncServer(Server):
return sources_with_metadata
def add_default_external_tools(self, actor: User) -> bool:
    """Register the default external (langchain, and optionally composio) tools.

    Returns True when every tool was created or updated successfully, False
    when at least one registration failed. Failures are reported as warnings
    rather than raised, so one bad tool does not block the rest.
    """
    # Start from the langchain defaults; include composio tools only when configured.
    pending = ToolCreate.load_default_langchain_tools()
    if tool_settings.composio_api_key:
        pending = pending + ToolCreate.load_default_composio_tools()

    all_ok = True
    for tool_create in pending:
        try:
            self.tool_manager.create_or_update_tool(Tool(**tool_create.model_dump()), actor=actor)
        except Exception as e:
            # Best-effort: surface the failure and keep processing the rest.
            warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
            warnings.warn(traceback.format_exc())
            all_ok = False
    return all_ok
def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message:
"""Update the details of a message associated with an agent"""

View File

@@ -83,9 +83,6 @@ class Settings(BaseSettings):
pg_pool_recycle: int = 1800 # When to recycle connections
pg_echo: bool = False # Logging
# tools configuration
load_default_external_tools: Optional[bool] = None
@property
def letta_pg_uri(self) -> str:
if self.pg_uri:

View File

@@ -56,10 +56,35 @@ def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
return decorator_retry
def retry_until_success(max_attempts=10, sleep_time_seconds=4):
    """
    Decorator to retry a function until it succeeds or the maximum number of attempts is reached.

    :param max_attempts: Maximum number of attempts to retry the function.
    :param sleep_time_seconds: Time to wait between attempts, in seconds.
    """

    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempt = 0
            while True:
                attempt += 1
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    print(f"\033[93mAttempt {attempt} failed with error:\n{e}\033[0m")
                    # Out of retries: surface the last failure to the caller.
                    if attempt == max_attempts:
                        raise
                    time.sleep(sleep_time_seconds)

        return wrapper

    return decorator_retry
# ======================================================================================================================
# OPENAI TESTS
# ======================================================================================================================
@retry_until_threshold(threshold=0.75, max_attempts=4)
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
@@ -67,6 +92,7 @@ def test_openai_gpt_4o_returns_valid_first_message():
print(f"Got successful response from client: \n\n{response}")
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_returns_keyword():
keyword = "banana"
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
@@ -75,6 +101,7 @@ def test_openai_gpt_4o_returns_keyword():
print(f"Got successful response from client: \n\n{response}")
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_uses_external_tool():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_uses_external_tool(filename)
@@ -82,6 +109,7 @@ def test_openai_gpt_4o_uses_external_tool():
print(f"Got successful response from client: \n\n{response}")
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_recall_chat_memory():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_recall_chat_memory(filename)
@@ -89,6 +117,7 @@ def test_openai_gpt_4o_recall_chat_memory():
print(f"Got successful response from client: \n\n{response}")
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_archival_memory_retrieval(filename)
@@ -96,6 +125,7 @@ def test_openai_gpt_4o_archival_memory_retrieval():
print(f"Got successful response from client: \n\n{response}")
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_archival_memory_insert():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_archival_memory_insert(filename)
@@ -103,6 +133,7 @@ def test_openai_gpt_4o_archival_memory_insert():
print(f"Got successful response from client: \n\n{response}")
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_edit_core_memory():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_edit_core_memory(filename)
@@ -110,12 +141,14 @@ def test_openai_gpt_4o_edit_core_memory():
print(f"Got successful response from client: \n\n{response}")
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_summarize_memory():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_summarize_memory_simple(filename)
print(f"Got successful response from client: \n\n{response}")
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_embedding_endpoint_openai():
filename = os.path.join(embedding_config_dir, "openai_embed.json")
run_embedding_endpoint(filename)

View File

@@ -362,10 +362,10 @@ def other_agent_id(server, user_id, base_tools):
server.agent_manager.delete_agent(agent_state.id, actor=actor)
def test_error_on_nonexistent_agent(server, user_id, agent_id):
def test_error_on_nonexistent_agent(server, user, agent_id):
try:
fake_agent_id = str(uuid.uuid4())
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
raise Exception("user_message call should have failed")
except (KeyError, ValueError) as e:
# Error is expected
@@ -375,9 +375,9 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id):
@pytest.mark.order(1)
def test_user_message_memory(server, user_id, agent_id):
def test_user_message_memory(server, user, agent_id):
try:
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
server.user_message(user_id=user.id, agent_id=agent_id, message="/memory")
raise Exception("user_message call should have failed")
except ValueError as e:
# Error is expected
@@ -385,13 +385,11 @@ def test_user_message_memory(server, user_id, agent_id):
except:
raise
server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
server.run_command(user_id=user.id, agent_id=agent_id, command="/memory")
@pytest.mark.order(3)
def test_load_data(server, user_id, agent_id):
user = server.user_manager.get_user_or_default(user_id=user_id)
def test_load_data(server, user, agent_id):
# create source
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
assert len(passages_before) == 0
@@ -409,10 +407,10 @@ def test_load_data(server, user_id, agent_id):
"Shishir loves indian food",
]
connector = DummyDataConnector(archival_memories)
server.load_data(user_id, connector, source.name)
server.load_data(user.id, connector, source.name)
# attach source
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=user)
# check archival memory size
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
@@ -425,9 +423,9 @@ def test_save_archival_memory(server, user_id, agent_id):
@pytest.mark.order(4)
def test_user_message(server, user_id, agent_id):
def test_user_message(server, user, agent_id):
# add data into recall memory
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
@@ -435,21 +433,20 @@ def test_user_message(server, user_id, agent_id):
@pytest.mark.order(5)
def test_get_recall_memory(server, org_id, user_id, agent_id):
def test_get_recall_memory(server, org_id, user, agent_id):
# test recall memory cursor pagination
actor = server.user_manager.get_user_or_default(user_id=user_id)
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
actor = user
messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=2)
cursor1 = messages_1[-1].id
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, after=cursor1, limit=1000)
messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=1000)
messages_3[-1].id
assert messages_3[-1].created_at >= messages_3[0].created_at
assert len(messages_3) == len(messages_1) + len(messages_2)
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, reverse=True, before=cursor1)
assert len(messages_4) == 1
# test in-context message ids
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
message_ids = [m.id for m in messages_3]
@@ -458,13 +455,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
@pytest.mark.order(6)
def test_get_archival_memory(server, user_id, agent_id):
def test_get_archival_memory(server, user, agent_id):
# test archival memory cursor pagination
user = server.user_manager.get_user_by_id(user_id=user_id)
actor = user
# List latest 2 passages
passages_1 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
limit=2,
@@ -474,7 +471,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
cursor=cursor1,
@@ -483,7 +480,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
@@ -496,20 +493,20 @@ def test_get_archival_memory(server, user_id, agent_id):
earliest = passages_2[-1]
# test archival memory
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0
def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
def test_get_context_window_overview(server: SyncServer, user, agent_id):
"""Test that the context window overview fetch works"""
overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id))
overview = server.get_agent_context_window(agent_id=agent_id, actor=user)
assert overview is not None
# Run some basic checks
@@ -546,7 +543,7 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
)
def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
agent_state = server.create_agent(
request=CreateAgent(
name="nonexistent_tools_agent",
@@ -554,7 +551,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
llm="openai/gpt-4",
embedding="openai/text-embedding-ada-002",
),
actor=server.user_manager.get_user_or_default(user_id),
actor=user,
)
# create another user in the same org
@@ -566,14 +563,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
def _test_get_messages_letta_format(
server,
user_id,
user,
agent_id,
reverse=False,
):
"""Test mapping between messages and letta_messages with reverse=False."""
messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
@@ -582,7 +579,7 @@ def _test_get_messages_letta_format(
assert all(isinstance(m, Message) for m in messages)
letta_messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
@@ -675,10 +672,10 @@ def _test_get_messages_letta_format(
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")
def test_get_messages_letta_format(server, user_id, agent_id):
def test_get_messages_letta_format(server, user, agent_id):
# for reverse in [False, True]:
for reverse in [False]:
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
_test_get_messages_letta_format(server, user, agent_id, reverse=reverse)
EXAMPLE_TOOL_SOURCE = '''
@@ -825,9 +822,9 @@ def test_composio_client_simple(server):
assert len(actions) > 0
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools):
def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)
actor = user
# create agent
agent_state = server.create_agent(
request=CreateAgent(
@@ -848,7 +845,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
# At this stage, there should only be 1 system message inside of recall storage
letta_messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_state.id,
limit=1000,
# reverse=reverse,
@@ -870,7 +867,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
assert num_system_messages == 1, (num_system_messages, all_messages)
# Assuming core memory append actually ran correctly, at this point there should be 2 messages
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
# At this stage, there should be 2 system message inside of recall storage
num_system_messages, all_messages = count_system_messages_in_recall()