chore: Clean up .load_agent usage (#2298)
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
"""Add cascading deletes for sources to agents
|
||||
|
||||
Revision ID: e78b4e82db30
|
||||
Revises: d6632deac81d
|
||||
Create Date: 2024-12-20 16:30:17.095888
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "e78b4e82db30"
|
||||
down_revision: Union[str, None] = "d6632deac81d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("sources_agents_agent_id_fkey", "sources_agents", type_="foreignkey")
|
||||
op.drop_constraint("sources_agents_source_id_fkey", "sources_agents", type_="foreignkey")
|
||||
op.create_foreign_key(None, "sources_agents", "sources", ["source_id"], ["id"], ondelete="CASCADE")
|
||||
op.create_foreign_key(None, "sources_agents", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, "sources_agents", type_="foreignkey")
|
||||
op.drop_constraint(None, "sources_agents", type_="foreignkey")
|
||||
op.create_foreign_key("sources_agents_source_id_fkey", "sources_agents", "sources", ["source_id"], ["id"])
|
||||
op.create_foreign_key("sources_agents_agent_id_fkey", "sources_agents", "agents", ["agent_id"], ["id"])
|
||||
# ### end Alembic commands ###
|
||||
@@ -44,7 +44,6 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
@@ -53,7 +52,6 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
)
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from letta.system import (
|
||||
@@ -969,32 +967,6 @@ class Agent(BaseAgent):
|
||||
# TODO: recall memory
|
||||
raise NotImplementedError()
|
||||
|
||||
def attach_source(
|
||||
self,
|
||||
user: PydanticUser,
|
||||
source_id: str,
|
||||
source_manager: SourceManager,
|
||||
agent_manager: AgentManager,
|
||||
):
|
||||
"""Attach a source to the agent using the SourcesAgents ORM relationship.
|
||||
|
||||
Args:
|
||||
user: User performing the action
|
||||
source_id: ID of the source to attach
|
||||
source_manager: SourceManager instance to verify source exists
|
||||
agent_manager: AgentManager instance to manage agent-source relationship
|
||||
"""
|
||||
# Verify source exists and user has permission to access it
|
||||
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
|
||||
assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})"
|
||||
|
||||
# Use the agent_manager to create the relationship
|
||||
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)
|
||||
|
||||
printd(
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}.",
|
||||
)
|
||||
|
||||
def get_context_window(self) -> ContextWindowOverview:
|
||||
"""Get the context window of the agent"""
|
||||
|
||||
|
||||
@@ -2987,7 +2987,11 @@ class LocalClient(AbstractClient):
|
||||
source_id (str): ID of the source
|
||||
source_name (str): Name of the source
|
||||
"""
|
||||
self.server.attach_source_to_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
|
||||
if source_name:
|
||||
source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
|
||||
source_id = source.id
|
||||
|
||||
self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user)
|
||||
|
||||
def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
|
||||
"""
|
||||
@@ -2999,7 +3003,10 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
source (Source): Detached source
|
||||
"""
|
||||
return self.server.detach_source_from_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
|
||||
if source_name:
|
||||
source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
|
||||
source_id = source.id
|
||||
return self.server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=self.user)
|
||||
|
||||
def list_sources(self) -> List[Source]:
|
||||
"""
|
||||
|
||||
@@ -11,10 +11,10 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.passage import SourcePassage
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import SourcePassage
|
||||
|
||||
|
||||
class Source(SqlalchemyBase, OrganizationMixin):
|
||||
@@ -32,4 +32,11 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
|
||||
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
|
||||
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent",
|
||||
secondary="sources_agents",
|
||||
back_populates="sources",
|
||||
lazy="selectin",
|
||||
cascade="all, delete", # Ensures rows in sources_agents are deleted when the source is deleted
|
||||
passive_deletes=True, # Allows the database to handle deletion of orphaned rows
|
||||
)
|
||||
|
||||
@@ -9,5 +9,5 @@ class SourcesAgents(Base):
|
||||
|
||||
__tablename__ = "sources_agents"
|
||||
|
||||
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id"), primary_key=True)
|
||||
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
|
||||
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), primary_key=True)
|
||||
|
||||
@@ -130,11 +130,8 @@ def attach_source_to_agent(
|
||||
Attach a data source to an existing agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id)
|
||||
return source
|
||||
server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=actor)
|
||||
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@router.post("/{source_id}/detach", response_model=Source, operation_id="detach_agent_from_source")
|
||||
@@ -148,8 +145,8 @@ def detach_source_from_agent(
|
||||
Detach a data source from an existing agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
|
||||
server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")
|
||||
|
||||
@@ -59,7 +59,7 @@ from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUp
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -303,11 +303,6 @@ class SyncServer(Server):
|
||||
self.block_manager.add_default_blocks(actor=self.default_user)
|
||||
self.tool_manager.upsert_base_tools(actor=self.default_user)
|
||||
|
||||
# If there is a default org/user
|
||||
# This logic may have to change in the future
|
||||
if settings.load_default_external_tools:
|
||||
self.add_default_external_tools(actor=self.default_user)
|
||||
|
||||
# collect providers (always has Letta as a default)
|
||||
self._enabled_providers: List[Provider] = [LettaProvider()]
|
||||
if model_settings.openai_api_key:
|
||||
@@ -431,9 +426,6 @@ class SyncServer(Server):
|
||||
skip_verify=True,
|
||||
)
|
||||
|
||||
# save agent after step
|
||||
# save_agent(letta_agent)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in server._step: {e}")
|
||||
print(traceback.print_exc())
|
||||
@@ -944,11 +936,10 @@ class SyncServer(Server):
|
||||
agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
for agent_state in agent_states:
|
||||
agent_id = agent_state.id
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Attach source to agent
|
||||
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
||||
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
|
||||
self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor)
|
||||
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
||||
assert new_passage_size >= curr_passage_size # in case empty files are added
|
||||
|
||||
@@ -973,56 +964,6 @@ class SyncServer(Server):
|
||||
passage_count, document_count = load_data(connector, source, self.passage_manager, self.source_manager, actor=user)
|
||||
return passage_count, document_count
|
||||
|
||||
def attach_source_to_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
source_id: Optional[str] = None,
|
||||
source_name: Optional[str] = None,
|
||||
) -> Source:
|
||||
# attach a data source to an agent
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
if source_id:
|
||||
data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
elif source_name:
|
||||
data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
||||
else:
|
||||
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
|
||||
|
||||
assert data_source, f"Data source with id={source_id} or name={source_name} does not exist"
|
||||
|
||||
# load agent
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# attach source to agent
|
||||
agent.attach_source(user=actor, source_id=data_source.id, source_manager=self.source_manager, agent_manager=self.agent_manager)
|
||||
|
||||
return data_source
|
||||
|
||||
def detach_source_from_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
source_id: Optional[str] = None,
|
||||
source_name: Optional[str] = None,
|
||||
) -> Source:
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
if source_id:
|
||||
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
elif source_name:
|
||||
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
||||
source_id = source.id
|
||||
else:
|
||||
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
|
||||
|
||||
# delete agent-source mapping
|
||||
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
|
||||
# return back source data
|
||||
return source
|
||||
|
||||
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
|
||||
warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
|
||||
return []
|
||||
@@ -1060,22 +1001,6 @@ class SyncServer(Server):
|
||||
|
||||
return sources_with_metadata
|
||||
|
||||
def add_default_external_tools(self, actor: User) -> bool:
|
||||
"""Add default langchain tools. Return true if successful, false otherwise."""
|
||||
success = True
|
||||
tool_creates = ToolCreate.load_default_langchain_tools()
|
||||
if tool_settings.composio_api_key:
|
||||
tool_creates += ToolCreate.load_default_composio_tools()
|
||||
for tool_create in tool_creates:
|
||||
try:
|
||||
self.tool_manager.create_or_update_tool(Tool(**tool_create.model_dump()), actor=actor)
|
||||
except Exception as e:
|
||||
warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
|
||||
warnings.warn(traceback.format_exc())
|
||||
success = False
|
||||
|
||||
return success
|
||||
|
||||
def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message:
|
||||
"""Update the details of a message associated with an agent"""
|
||||
|
||||
|
||||
@@ -83,9 +83,6 @@ class Settings(BaseSettings):
|
||||
pg_pool_recycle: int = 1800 # When to recycle connections
|
||||
pg_echo: bool = False # Logging
|
||||
|
||||
# tools configuration
|
||||
load_default_external_tools: Optional[bool] = None
|
||||
|
||||
@property
|
||||
def letta_pg_uri(self) -> str:
|
||||
if self.pg_uri:
|
||||
|
||||
@@ -56,10 +56,35 @@ def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
|
||||
return decorator_retry
|
||||
|
||||
|
||||
def retry_until_success(max_attempts=10, sleep_time_seconds=4):
|
||||
"""
|
||||
Decorator to retry a function until it succeeds or the maximum number of attempts is reached.
|
||||
|
||||
:param max_attempts: Maximum number of attempts to retry the function.
|
||||
:param sleep_time_seconds: Time to wait between attempts, in seconds.
|
||||
"""
|
||||
|
||||
def decorator_retry(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
print(f"\033[93mAttempt {attempt} failed with error:\n{e}\033[0m")
|
||||
if attempt == max_attempts:
|
||||
raise
|
||||
time.sleep(sleep_time_seconds)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator_retry
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# OPENAI TESTS
|
||||
# ======================================================================================================================
|
||||
@retry_until_threshold(threshold=0.75, max_attempts=4)
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_openai_gpt_4o_returns_valid_first_message():
|
||||
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
|
||||
response = check_first_response_is_valid_for_llm_endpoint(filename)
|
||||
@@ -67,6 +92,7 @@ def test_openai_gpt_4o_returns_valid_first_message():
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_openai_gpt_4o_returns_keyword():
|
||||
keyword = "banana"
|
||||
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
|
||||
@@ -75,6 +101,7 @@ def test_openai_gpt_4o_returns_keyword():
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_openai_gpt_4o_uses_external_tool():
|
||||
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
|
||||
response = check_agent_uses_external_tool(filename)
|
||||
@@ -82,6 +109,7 @@ def test_openai_gpt_4o_uses_external_tool():
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_openai_gpt_4o_recall_chat_memory():
|
||||
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
|
||||
response = check_agent_recall_chat_memory(filename)
|
||||
@@ -89,6 +117,7 @@ def test_openai_gpt_4o_recall_chat_memory():
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_openai_gpt_4o_archival_memory_retrieval():
|
||||
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
|
||||
response = check_agent_archival_memory_retrieval(filename)
|
||||
@@ -96,6 +125,7 @@ def test_openai_gpt_4o_archival_memory_retrieval():
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_openai_gpt_4o_archival_memory_insert():
|
||||
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
|
||||
response = check_agent_archival_memory_insert(filename)
|
||||
@@ -103,6 +133,7 @@ def test_openai_gpt_4o_archival_memory_insert():
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_openai_gpt_4o_edit_core_memory():
|
||||
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
|
||||
response = check_agent_edit_core_memory(filename)
|
||||
@@ -110,12 +141,14 @@ def test_openai_gpt_4o_edit_core_memory():
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_openai_gpt_4o_summarize_memory():
|
||||
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
|
||||
response = check_agent_summarize_memory_simple(filename)
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_embedding_endpoint_openai():
|
||||
filename = os.path.join(embedding_config_dir, "openai_embed.json")
|
||||
run_embedding_endpoint(filename)
|
||||
|
||||
@@ -362,10 +362,10 @@ def other_agent_id(server, user_id, base_tools):
|
||||
server.agent_manager.delete_agent(agent_state.id, actor=actor)
|
||||
|
||||
|
||||
def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
def test_error_on_nonexistent_agent(server, user, agent_id):
|
||||
try:
|
||||
fake_agent_id = str(uuid.uuid4())
|
||||
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
|
||||
raise Exception("user_message call should have failed")
|
||||
except (KeyError, ValueError) as e:
|
||||
# Error is expected
|
||||
@@ -375,9 +375,9 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(1)
|
||||
def test_user_message_memory(server, user_id, agent_id):
|
||||
def test_user_message_memory(server, user, agent_id):
|
||||
try:
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
|
||||
server.user_message(user_id=user.id, agent_id=agent_id, message="/memory")
|
||||
raise Exception("user_message call should have failed")
|
||||
except ValueError as e:
|
||||
# Error is expected
|
||||
@@ -385,13 +385,11 @@ def test_user_message_memory(server, user_id, agent_id):
|
||||
except:
|
||||
raise
|
||||
|
||||
server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
|
||||
server.run_command(user_id=user.id, agent_id=agent_id, command="/memory")
|
||||
|
||||
|
||||
@pytest.mark.order(3)
|
||||
def test_load_data(server, user_id, agent_id):
|
||||
user = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
def test_load_data(server, user, agent_id):
|
||||
# create source
|
||||
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
|
||||
assert len(passages_before) == 0
|
||||
@@ -409,10 +407,10 @@ def test_load_data(server, user_id, agent_id):
|
||||
"Shishir loves indian food",
|
||||
]
|
||||
connector = DummyDataConnector(archival_memories)
|
||||
server.load_data(user_id, connector, source.name)
|
||||
server.load_data(user.id, connector, source.name)
|
||||
|
||||
# attach source
|
||||
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
|
||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=user)
|
||||
|
||||
# check archival memory size
|
||||
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
|
||||
@@ -425,9 +423,9 @@ def test_save_archival_memory(server, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(4)
|
||||
def test_user_message(server, user_id, agent_id):
|
||||
def test_user_message(server, user, agent_id):
|
||||
# add data into recall memory
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
@@ -435,21 +433,20 @@ def test_user_message(server, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(5)
|
||||
def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
def test_get_recall_memory(server, org_id, user, agent_id):
|
||||
# test recall memory cursor pagination
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
|
||||
actor = user
|
||||
messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=2)
|
||||
cursor1 = messages_1[-1].id
|
||||
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
|
||||
messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=1000)
|
||||
messages_3[-1].id
|
||||
assert messages_3[-1].created_at >= messages_3[0].created_at
|
||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
# test in-context message ids
|
||||
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
|
||||
in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
|
||||
message_ids = [m.id for m in messages_3]
|
||||
@@ -458,13 +455,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(6)
|
||||
def test_get_archival_memory(server, user_id, agent_id):
|
||||
def test_get_archival_memory(server, user, agent_id):
|
||||
# test archival memory cursor pagination
|
||||
user = server.user_manager.get_user_by_id(user_id=user_id)
|
||||
actor = user
|
||||
|
||||
# List latest 2 passages
|
||||
passages_1 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
limit=2,
|
||||
@@ -474,7 +471,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
# List next 3 passages (earliest 3)
|
||||
cursor1 = passages_1[-1].id
|
||||
passages_2 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
cursor=cursor1,
|
||||
@@ -483,7 +480,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
# List all 5
|
||||
cursor2 = passages_1[0].created_at
|
||||
passages_3 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
end_date=cursor2,
|
||||
@@ -496,20 +493,20 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
earliest = passages_2[-1]
|
||||
|
||||
# test archival memory
|
||||
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
|
||||
passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
|
||||
assert len(passage_1) == 1
|
||||
assert passage_1[0].text == "alpha"
|
||||
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
|
||||
passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
|
||||
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert all("alpha" not in passage.text for passage in passage_2)
|
||||
# test safe empty return
|
||||
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
|
||||
passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
|
||||
def test_get_context_window_overview(server: SyncServer, user, agent_id):
|
||||
"""Test that the context window overview fetch works"""
|
||||
overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id))
|
||||
overview = server.get_agent_context_window(agent_id=agent_id, actor=user)
|
||||
assert overview is not None
|
||||
|
||||
# Run some basic checks
|
||||
@@ -546,7 +543,7 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
|
||||
)
|
||||
|
||||
|
||||
def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="nonexistent_tools_agent",
|
||||
@@ -554,7 +551,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
llm="openai/gpt-4",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=server.user_manager.get_user_or_default(user_id),
|
||||
actor=user,
|
||||
)
|
||||
|
||||
# create another user in the same org
|
||||
@@ -566,14 +563,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
|
||||
def _test_get_messages_letta_format(
|
||||
server,
|
||||
user_id,
|
||||
user,
|
||||
agent_id,
|
||||
reverse=False,
|
||||
):
|
||||
"""Test mapping between messages and letta_messages with reverse=False."""
|
||||
|
||||
messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
agent_id=agent_id,
|
||||
limit=1000,
|
||||
reverse=reverse,
|
||||
@@ -582,7 +579,7 @@ def _test_get_messages_letta_format(
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
letta_messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
agent_id=agent_id,
|
||||
limit=1000,
|
||||
reverse=reverse,
|
||||
@@ -675,10 +672,10 @@ def _test_get_messages_letta_format(
|
||||
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")
|
||||
|
||||
|
||||
def test_get_messages_letta_format(server, user_id, agent_id):
|
||||
def test_get_messages_letta_format(server, user, agent_id):
|
||||
# for reverse in [False, True]:
|
||||
for reverse in [False]:
|
||||
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
|
||||
_test_get_messages_letta_format(server, user, agent_id, reverse=reverse)
|
||||
|
||||
|
||||
EXAMPLE_TOOL_SOURCE = '''
|
||||
@@ -825,9 +822,9 @@ def test_composio_client_simple(server):
|
||||
assert len(actions) > 0
|
||||
|
||||
|
||||
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools):
|
||||
def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, base_memory_tools):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
actor = user
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
@@ -848,7 +845,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
|
||||
|
||||
# At this stage, there should only be 1 system message inside of recall storage
|
||||
letta_messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
agent_id=agent_state.id,
|
||||
limit=1000,
|
||||
# reverse=reverse,
|
||||
@@ -870,7 +867,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
|
||||
assert num_system_messages == 1, (num_system_messages, all_messages)
|
||||
|
||||
# Assuming core memory append actually ran correctly, at this point there should be 2 messages
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
|
||||
|
||||
# At this stage, there should be 2 system message inside of recall storage
|
||||
num_system_messages, all_messages = count_system_messages_in_recall()
|
||||
|
||||
Reference in New Issue
Block a user