diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index ab221e01..3d2292f3 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -18,7 +18,7 @@ on: branches: [ main ] jobs: - run-integration-tests: + integ-run: runs-on: ubuntu-latest timeout-minutes: 15 strategy: @@ -27,6 +27,9 @@ jobs: integration_test_suite: - "integration_test_summarizer.py" - "integration_test_tool_execution_sandbox.py" + - "integration_test_offline_memory_agent.py" + - "integration_test_agent_tool_graph.py" + - "integration_test_o1_agent.py" services: qdrant: image: qdrant/qdrant diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e1f40e9b..43e66727 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,25 +13,27 @@ on: pull_request: jobs: - run-core-unit-tests: + unit-run: runs-on: ubuntu-latest timeout-minutes: 15 strategy: fail-fast: false matrix: test_suite: + - "test_vector_embeddings.py" - "test_client.py" - - "test_local_client.py" - "test_client_legacy.py" - "test_server.py" - - "test_managers.py" - - "test_o1_agent.py" - - "test_tool_rule_solver.py" - - "test_agent_tool_graph.py" - - "test_utils.py" - - "test_tool_schema_parsing.py" - "test_v1_routes.py" - - "test_offline_memory_agent.py" + - "test_local_client.py" + - "test_managers.py" + - "test_base_functions.py" + - "test_tool_schema_parsing.py" + - "test_tool_rule_solver.py" + - "test_memory.py" + - "test_utils.py" + - "test_stream_buffer_readers.py" + - "test_summarize.py" services: qdrant: image: qdrant/qdrant @@ -81,57 +83,3 @@ jobs: LETTA_SERVER_PASS: test_server_token run: | poetry run pytest -s -vv tests/${{ matrix.test_suite }} - - misc-unit-tests: - runs-on: ubuntu-latest - needs: run-core-unit-tests - services: - qdrant: - image: qdrant/qdrant - ports: - - 6333:6333 - postgres: - image: pgvector/pgvector:pg17 - ports: - - 5432:5432 - env: - POSTGRES_HOST_AUTH_METHOD: trust - POSTGRES_DB: postgres - POSTGRES_USER: postgres - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Setup Python, Poetry, and Dependencies - uses: packetcoders/action-setup-cache-python-poetry@main - with: - python-version: "3.12" - poetry-version: "1.8.2" - install-args: "-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox" - - name: Migrate database - env: - LETTA_PG_PORT: 5432 - LETTA_PG_USER: postgres - LETTA_PG_PASSWORD: postgres - LETTA_PG_DB: postgres - LETTA_PG_HOST: localhost - run: | - psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector' - poetry run alembic upgrade head - - name: Run misc unit tests - env: - LETTA_PG_PORT: 5432 - LETTA_PG_USER: postgres - LETTA_PG_PASSWORD: postgres - LETTA_PG_DB: postgres - LETTA_PG_HOST: localhost - LETTA_SERVER_PASS: test_server_token - PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} - run: | - poetry run pytest -s -vv -k "not test_offline_memory_agent.py and not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests diff --git a/alembic/versions/d05669b60ebe_migrate_agents_to_orm.py b/alembic/versions/d05669b60ebe_migrate_agents_to_orm.py new file mode 100644 index 00000000..d03652c8 --- /dev/null +++ b/alembic/versions/d05669b60ebe_migrate_agents_to_orm.py @@ -0,0 +1,175 @@ +"""Migrate agents to orm + +Revision ID: d05669b60ebe +Revises: c5d964280dff +Create Date: 2024-12-12 10:25:31.825635 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d05669b60ebe" +down_revision: Union[str, None] = "c5d964280dff" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "sources_agents", + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("source_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["agent_id"], + ["agents.id"], + ), + sa.ForeignKeyConstraint( + ["source_id"], + ["sources.id"], + ), + sa.PrimaryKeyConstraint("agent_id", "source_id"), + ) + op.drop_index("agent_source_mapping_idx_user", table_name="agent_source_mapping") + op.drop_table("agent_source_mapping") + op.add_column("agents", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("agents", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("agents", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("agents", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) + op.add_column("agents", sa.Column("organization_id", sa.String(), nullable=True)) + # Populate `organization_id` based on `user_id` + # Use a raw SQL query to update the organization_id + op.execute( + """ + UPDATE agents + SET organization_id = users.organization_id + FROM users + WHERE agents.user_id = users.id + """ + ) + op.alter_column("agents", "organization_id", nullable=False) + op.alter_column("agents", "name", existing_type=sa.VARCHAR(), nullable=True) + op.drop_index("agents_idx_user", table_name="agents") + op.create_unique_constraint("unique_org_agent_name", "agents", ["organization_id", "name"]) + op.create_foreign_key(None, "agents", "organizations", ["organization_id"], ["id"]) + op.drop_column("agents", "tool_names") + op.drop_column("agents", "user_id") + op.drop_constraint("agents_tags_organization_id_fkey", "agents_tags", type_="foreignkey") + op.drop_column("agents_tags", "_created_by_id") + op.drop_column("agents_tags", "_last_updated_by_id") + op.drop_column("agents_tags", "updated_at") + op.drop_column("agents_tags", "id") + op.drop_column("agents_tags", "is_deleted") + op.drop_column("agents_tags", "created_at") + op.drop_column("agents_tags", "organization_id") + op.create_unique_constraint("unique_agent_block", "blocks_agents", ["agent_id", "block_id"]) + op.drop_constraint("fk_block_id_label", "blocks_agents", type_="foreignkey") + op.create_foreign_key( + "fk_block_id_label", "blocks_agents", "block", ["block_id", "block_label"], ["id", "label"], initially="DEFERRED", deferrable=True + ) + op.drop_column("blocks_agents", "_created_by_id") + op.drop_column("blocks_agents", "_last_updated_by_id") + op.drop_column("blocks_agents", "updated_at") + op.drop_column("blocks_agents", "id") + op.drop_column("blocks_agents", "is_deleted") + op.drop_column("blocks_agents", "created_at") + op.drop_constraint("unique_tool_per_agent", "tools_agents", type_="unique") + op.create_unique_constraint("unique_agent_tool", "tools_agents", ["agent_id", "tool_id"]) + op.drop_constraint("fk_tool_id", "tools_agents", type_="foreignkey") + op.drop_constraint("tools_agents_agent_id_fkey", "tools_agents", type_="foreignkey") + op.create_foreign_key(None, "tools_agents", "tools", ["tool_id"], ["id"], ondelete="CASCADE") + op.create_foreign_key(None, "tools_agents", "agents", ["agent_id"], ["id"], ondelete="CASCADE") + op.drop_column("tools_agents", "_created_by_id") + op.drop_column("tools_agents", "tool_name") + op.drop_column("tools_agents", "_last_updated_by_id") + op.drop_column("tools_agents", "updated_at") + op.drop_column("tools_agents", "id") + op.drop_column("tools_agents", "is_deleted") + op.drop_column("tools_agents", "created_at") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "tools_agents", + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + ) + op.add_column( + "tools_agents", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False) + ) + op.add_column("tools_agents", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.add_column( + "tools_agents", + sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + ) + op.add_column("tools_agents", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.add_column("tools_agents", sa.Column("tool_name", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.add_column("tools_agents", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint(None, "tools_agents", type_="foreignkey") + op.drop_constraint(None, "tools_agents", type_="foreignkey") + op.create_foreign_key("tools_agents_agent_id_fkey", "tools_agents", "agents", ["agent_id"], ["id"]) + op.create_foreign_key("fk_tool_id", "tools_agents", "tools", ["tool_id"], ["id"]) + op.drop_constraint("unique_agent_tool", "tools_agents", type_="unique") + op.create_unique_constraint("unique_tool_per_agent", "tools_agents", ["agent_id", "tool_name"]) + op.add_column( + "blocks_agents", + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + ) + op.add_column( + "blocks_agents", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False) + ) + op.add_column("blocks_agents", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.add_column( + "blocks_agents", + sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + ) + op.add_column("blocks_agents", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.add_column("blocks_agents", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint("fk_block_id_label", "blocks_agents", type_="foreignkey") + op.create_foreign_key("fk_block_id_label", "blocks_agents", "block", ["block_id", "block_label"], ["id", "label"]) + op.drop_constraint("unique_agent_block", "blocks_agents", type_="unique") + op.add_column("agents_tags", sa.Column("organization_id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.add_column( + "agents_tags", + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + ) + op.add_column( + "agents_tags", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False) + ) + op.add_column("agents_tags", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.add_column( + "agents_tags", + sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + ) + op.add_column("agents_tags", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.add_column("agents_tags", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.create_foreign_key("agents_tags_organization_id_fkey", "agents_tags", "organizations", ["organization_id"], ["id"]) + op.add_column("agents", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.add_column("agents", sa.Column("tool_names", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True)) + op.drop_constraint(None, "agents", type_="foreignkey") + op.drop_constraint("unique_org_agent_name", "agents", type_="unique") + op.create_index("agents_idx_user", "agents", ["user_id"], unique=False) + op.alter_column("agents", "name", existing_type=sa.VARCHAR(), nullable=False) + op.drop_column("agents", "organization_id") + op.drop_column("agents", "_last_updated_by_id") + op.drop_column("agents", "_created_by_id") + op.drop_column("agents", "is_deleted") + op.drop_column("agents", "updated_at") + op.create_table( + "agent_source_mapping", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("source_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint("id", name="agent_source_mapping_pkey"), + ) + op.create_index("agent_source_mapping_idx_user", "agent_source_mapping", ["user_id", "agent_id", "source_id"], unique=False) + op.drop_table("sources_agents") + # ### end Alembic commands ### diff --git a/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py b/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py index c05775eb..e8733313 100644 --- a/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py +++ b/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py @@ -30,7 +30,7 @@ def upgrade() -> None: op.execute("DELETE FROM tools") # ### commands auto generated by Alembic - please adjust! ### - op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True)) + op.add_column("agents", sa.Column("tool_rules", letta.orm.agent.ToolRulesColumn(), nullable=True)) op.alter_column("block", "name", new_column_name="template_name", nullable=True) op.add_column("organizations", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) op.add_column("organizations", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) diff --git a/examples/composio_tool_usage.py b/examples/composio_tool_usage.py index 877d3754..c3c81895 100644 --- a/examples/composio_tool_usage.py +++ b/examples/composio_tool_usage.py @@ -74,7 +74,7 @@ def main(): """ # Create an agent - agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tools=[tool.name]) + agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[tool.id]) print(f"Created agent: {agent.name} with ID {str(agent.id)}") # Send a message to the agent diff --git a/examples/docs/agent_advanced.py b/examples/docs/agent_advanced.py index 311aa92a..95da5c34 100644 --- a/examples/docs/agent_advanced.py +++ b/examples/docs/agent_advanced.py @@ -29,7 +29,7 @@ agent_state = client.create_agent( # whether to include base letta tools (default: True) include_base_tools=True, # list of additional tools (by name) to add to the agent - tools=[], + tool_ids=[], ) print(f"Created agent with name {agent_state.name} and unique ID {agent_state.id}") diff --git a/examples/docs/tools.py b/examples/docs/tools.py index b41fb501..837a7dda 100644 --- a/examples/docs/tools.py +++ b/examples/docs/tools.py @@ -36,7 +36,7 @@ print(f"Created tool with name {tool.name}") # create a new agent agent_state = client.create_agent( # create the agent with an additional tool - tools=[tool.name], + tool_ids=[tool.id], # add tool rules that terminate execution after specific tools tool_rules=[ # exit after roll_d20 is called @@ -45,7 +45,7 @@ agent_state = client.create_agent( TerminalToolRule(tool_name="send_message"), ], ) -print(f"Created agent with name {agent_state.name} with tools {agent_state.tool_names}") +print(f"Created agent with name {agent_state.name} with tools {[t.name for t in agent_state.tools]}") # Message an agent response = client.send_message(agent_id=agent_state.id, role="user", message="roll a dice") @@ -61,7 +61,8 @@ client.add_tool_to_agent(agent_id=agent_state.id, tool_id=tool.id) client.delete_agent(agent_id=agent_state.id) # create an agent with only a subset of default tools -agent_state = client.create_agent(include_base_tools=False, tools=[tool.name, "send_message"]) +send_message_tool = client.get_tool_id("send_message") +agent_state = client.create_agent(include_base_tools=False, tool_ids=[tool.id, send_message_tool]) # message the agent to search archival memory (will be unable to do so) response = client.send_message(agent_id=agent_state.id, role="user", message="search your archival memory") diff --git a/examples/langchain_tool_usage.py b/examples/langchain_tool_usage.py index eb207694..cf55d120 100644 --- a/examples/langchain_tool_usage.py +++ b/examples/langchain_tool_usage.py @@ -67,7 +67,9 @@ def main(): """ # Create an agent - agent_state = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tools=[tool_name]) + agent_state = client.create_agent( + name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[wikipedia_query_tool.id] + ) print(f"Created agent: {agent_state.name} with ID {str(agent_state.id)}") # Send a message to the agent diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py index aca7c4f8..45c56ec3 100644 --- a/examples/tool_rule_usage.py +++ b/examples/tool_rule_usage.py @@ -108,7 +108,7 @@ def main(): ] # 4. Create the agent - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules) + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) # 5. Ask for the final secret word response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?") diff --git a/letta/__init__.py b/letta/__init__.py index 0a397aed..3936cbb7 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -4,7 +4,7 @@ __version__ = "0.6.4" from letta.client.client import LocalClient, RESTClient, create_client # imports for easier access -from letta.schemas.agent import AgentState, PersistedAgentState +from letta.schemas.agent import AgentState from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus diff --git a/letta/agent.py b/letta/agent.py index 4060c363..9c136da8 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -6,8 +6,6 @@ import warnings from abc import ABC, abstractmethod from typing import List, Literal, Optional, Tuple, Union -from tqdm import tqdm - from letta.constants import ( BASE_TOOLS, CLI_WARNING_PREFIX, @@ -30,7 +28,7 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes from letta.memory import summarize_messages from letta.metadata import MetadataStore from letta.orm import User -from letta.schemas.agent import AgentState, AgentStepResponse +from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole @@ -49,12 +47,12 @@ from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User as PydanticUser +from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox -from letta.services.user_manager import UserManager from letta.streaming_interface import StreamingRefreshCLIInterface from letta.system import ( get_heartbeat, @@ -316,7 +314,7 @@ class Agent(BaseAgent): else: printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}") - assert self.agent_state.id is not None and self.agent_state.user_id is not None + assert self.agent_state.id is not None and self.agent_state.created_by_id is not None # Generate a sequence of initial messages to put in the buffer init_messages = initialize_message_sequence( @@ -335,7 +333,7 @@ class Agent(BaseAgent): # We always need the system prompt up front system_message_obj = Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=init_messages[0], ) @@ -358,7 +356,7 @@ class Agent(BaseAgent): # Cast to Message objects init_messages = [ Message.dict_to_message( - agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg + agent_id=self.agent_state.id, user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=msg ) for msg in init_messages ] @@ -439,11 +437,12 @@ class Agent(BaseAgent): else: # execute tool in a sandbox # TODO: allow agent_state to specify which sandbox to execute tools in - sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run( + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.created_by_id).run( agent_state=self.agent_state.__deepcopy__() ) function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" + self.update_memory_if_change(updated_agent_state.memory) except Exception as e: # Need to catch error here, or else trunction wont happen @@ -573,7 +572,7 @@ class Agent(BaseAgent): added_messages_objs = [ Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=msg, ) @@ -603,7 +602,7 @@ class Agent(BaseAgent): response = create( llm_config=self.agent_state.llm_config, messages=message_sequence, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, functions=allowed_functions, functions_python=self.functions_python, function_call=function_call, @@ -689,7 +688,7 @@ class Agent(BaseAgent): Message.dict_to_message( id=response_message_id, agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=response_message.model_dump(), ) @@ -722,7 +721,7 @@ class Agent(BaseAgent): messages.append( Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict={ "role": "tool", @@ -745,7 +744,7 @@ class Agent(BaseAgent): messages.append( Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict={ "role": "tool", @@ -823,7 +822,7 @@ class Agent(BaseAgent): messages.append( Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict={ "role": "tool", @@ -842,7 +841,7 @@ class Agent(BaseAgent): messages.append( Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict={ "role": "tool", @@ -861,7 +860,7 @@ class Agent(BaseAgent): Message.dict_to_message( id=response_message_id, agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=response_message.model_dump(), ) @@ -920,7 +919,7 @@ class Agent(BaseAgent): # logger.debug("Saving agent state") # save updated state if ms: - save_agent(self, ms) + save_agent(self) # Chain stops if not chaining: @@ -931,10 +930,10 @@ class Agent(BaseAgent): break # Chain handlers elif token_warning: - assert self.agent_state.user_id is not None + assert self.agent_state.created_by_id is not None next_input_message = Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict={ "role": "user", # TODO: change to system? @@ -943,10 +942,10 @@ class Agent(BaseAgent): ) continue # always chain elif function_failed: - assert self.agent_state.user_id is not None + assert self.agent_state.created_by_id is not None next_input_message = Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict={ "role": "user", # TODO: change to system? @@ -955,10 +954,10 @@ class Agent(BaseAgent): ) continue # always chain elif heartbeat_request: - assert self.agent_state.user_id is not None + assert self.agent_state.created_by_id is not None next_input_message = Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict={ "role": "user", # TODO: change to system? @@ -1129,10 +1128,10 @@ class Agent(BaseAgent): openai_message_dict = {"role": "user", "content": cleaned_user_message_text, "name": name} # Create the associated Message object (in the database) - assert self.agent_state.user_id is not None, "User ID is not set" + assert self.agent_state.created_by_id is not None, "User ID is not set" user_message = Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=openai_message_dict, # created_at=timestamp, @@ -1232,7 +1231,7 @@ class Agent(BaseAgent): [ Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=packed_summary_message, ) @@ -1260,7 +1259,7 @@ class Agent(BaseAgent): assert isinstance(new_system_message, str) new_system_message_obj = Message.dict_to_message( agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, + user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict={"role": "system", "content": new_system_message}, ) @@ -1371,7 +1370,14 @@ class Agent(BaseAgent): # TODO: recall memory raise NotImplementedError() - def attach_source(self, user: PydanticUser, source_id: str, source_manager: SourceManager, ms: MetadataStore, page_size: Optional[int] = None): + def attach_source( + self, + user: PydanticUser, + source_id: str, + source_manager: SourceManager, + agent_manager: AgentManager, + page_size: Optional[int] = None, + ): """Attach data with name `source_name` to the agent from source_connector.""" # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size) @@ -1384,7 +1390,7 @@ class Agent(BaseAgent): agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size) passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id) assert all([p.agent_id == self.agent_state.id for p in agents_passages]) - assert len(agents_passages) == passage_size # sanity check + assert len(agents_passages) == passage_size # sanity check assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}" # attach to agent @@ -1393,7 +1399,7 @@ class Agent(BaseAgent): # NOTE: need this redundant line here because we haven't migrated agent to ORM yet # TODO: delete @matt and remove - ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id) + agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user) printd( f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.", @@ -1610,20 +1616,31 @@ class Agent(BaseAgent): return context_window_breakdown.context_window_size_current -def save_agent(agent: Agent, ms: MetadataStore): +def save_agent(agent: Agent): """Save agent to metadata store""" - agent.update_state() agent_state = agent.agent_state assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}" # TODO: move this to agent manager + # TODO: Completely strip out metadata # convert to persisted model - persisted_agent_state = agent.agent_state.to_persisted_agent_state() - if ms.get_agent(agent_id=persisted_agent_state.id): - ms.update_agent(persisted_agent_state) - else: - ms.create_agent(persisted_agent_state) + agent_manager = AgentManager() + update_agent = UpdateAgent( + name=agent_state.name, + tool_ids=[t.id for t in agent_state.tools], + source_ids=[s.id for s in agent_state.sources], + block_ids=[b.id for b in agent_state.memory.blocks], + tags=agent_state.tags, + system=agent_state.system, + tool_rules=agent_state.tool_rules, + llm_config=agent_state.llm_config, + embedding_config=agent_state.embedding_config, + message_ids=agent_state.message_ids, + description=agent_state.description, + metadata_=agent_state.metadata_, + ) + agent_manager.update_agent(agent_id=agent_state.id, agent_update=update_agent, actor=agent.user) def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]: diff --git a/letta/chat_only_agent.py b/letta/chat_only_agent.py index 2051a547..eb029e93 100644 --- a/letta/chat_only_agent.py +++ b/letta/chat_only_agent.py @@ -2,7 +2,6 @@ from concurrent.futures import ThreadPoolExecutor from typing import List, Optional, Union from letta.agent import Agent - from letta.interface import AgentInterface from letta.metadata import MetadataStore from letta.prompts import gpt_system @@ -68,8 +67,10 @@ class ChatOnlyAgent(Agent): name="chat_agent_persona_new", label="chat_agent_persona_new", value=conversation_persona_block.value, limit=2000 ) - recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit:] - conversation_messages_block = Block(name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit) + recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit :] + conversation_messages_block = Block( + name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit + ) offline_memory = BasicBlockMemory( blocks=[ @@ -89,7 +90,7 @@ class ChatOnlyAgent(Agent): memory=offline_memory, llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - tools=self.agent_state.metadata_.get("offline_memory_tools", []), + tool_ids=self.agent_state.metadata_.get("offline_memory_tools", []), include_base_tools=False, ) self.offline_memory_agent.memory.update_block_value(label="conversation_block", value=recent_convo) diff --git a/letta/cli/cli.py b/letta/cli/cli.py index 81e27c9a..e941589d 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -140,7 +140,6 @@ def run( # read user id from config ms = MetadataStore(config) client = create_client() - server = client.server # determine agent to use, if not provided if not yes and not agent: @@ -165,8 +164,6 @@ def run( persona = persona if persona else config.persona if agent and agent_state: # use existing agent typer.secho(f"\n🔁 Using existing agent {agent}", fg=typer.colors.GREEN) - # agent_config = AgentConfig.load(agent) - # agent_state = ms.get_agent(agent_name=agent, user_id=user_id) printd("Loading agent state:", agent_state.id) printd("Agent state:", agent_state.name) # printd("State path:", agent_config.save_state_dir()) @@ -224,8 +221,6 @@ def run( ) # create agent - tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tool_names] - agent_state.tools = tools letta_agent = Agent(agent_state=agent_state, interface=interface(), user=client.user) else: # create new agent @@ -317,7 +312,7 @@ def run( metadata=metadata, ) assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}" - typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tool_names])}", fg=typer.colors.WHITE) + typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t.name for t in agent_state.tools])}", fg=typer.colors.WHITE) letta_agent = Agent( interface=interface(), @@ -326,7 +321,7 @@ def run( first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False, user=client.user, ) - save_agent(agent=letta_agent, ms=ms) + save_agent(agent=letta_agent) typer.secho(f"🎉 Created new agent '{letta_agent.agent_state.name}' (id={letta_agent.agent_state.id})", fg=typer.colors.GREEN) # start event loop diff --git a/letta/client/client.py b/letta/client/client.py index 5cf6f9ad..d3259214 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -15,7 +15,8 @@ from letta.constants import ( ) from letta.data_sources.connectors import DataConnector from letta.functions.functions import parse_source_code -from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState +from letta.orm.errors import NoResultFound +from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent from letta.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona from letta.schemas.embedding_config import EmbeddingConfig @@ -65,10 +66,8 @@ def create_client(base_url: Optional[str] = None, token: Optional[str] = None): class AbstractClient(object): def __init__( self, - auto_save: bool = False, debug: bool = False, ): - self.auto_save = auto_save self.debug = debug def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: @@ -81,8 +80,9 @@ class AbstractClient(object): embedding_config: Optional[EmbeddingConfig] = None, llm_config: Optional[LLMConfig] = None, memory=None, + block_ids: Optional[List[str]] = None, system: Optional[str] = None, - tools: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, tool_rules: Optional[List[BaseToolRule]] = None, include_base_tools: Optional[bool] = True, metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, @@ -97,7 +97,7 @@ class AbstractClient(object): name: Optional[str] = None, description: Optional[str] = None, system: Optional[str] = None, - tools: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, metadata: Optional[Dict] = None, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, @@ -436,7 +436,6 @@ class RESTClient(AbstractClient): Initializes a new instance of Client class. Args: - auto_save (bool): Whether to automatically save changes. user_id (str): The user ID. debug (bool): Whether to print debug information. default_llm_config (Optional[LLMConfig]): The default LLM configuration. @@ -456,6 +455,7 @@ class RESTClient(AbstractClient): params = {} if tags: params["tags"] = tags + params["match_all_tags"] = False response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params) return [AgentState(**agent) for agent in response.json()] @@ -491,10 +491,12 @@ class RESTClient(AbstractClient): llm_config: LLMConfig = None, # memory memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), + # Existing blocks + block_ids: Optional[List[str]] = None, # system system: Optional[str] = None, # tools - tools: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, tool_rules: Optional[List[BaseToolRule]] = None, include_base_tools: Optional[bool] = True, # metadata @@ -511,7 +513,7 @@ class RESTClient(AbstractClient): llm_config (LLMConfig): LLM configuration memory (Memory): Memory configuration system (str): System configuration - tools (List[str]): List of tools + tool_ids (List[str]): List of tool ids include_base_tools (bool): Include base tools metadata (Dict): Metadata description (str): Description @@ -520,31 +522,54 @@ class RESTClient(AbstractClient): Returns: agent_state (AgentState): State of the created agent """ + tool_ids = tool_ids or [] tool_names = [] - if tools: - tool_names += tools if include_base_tools: tool_names += BASE_TOOLS tool_names += BASE_MEMORY_TOOLS + tool_ids += [self.get_tool_id(tool_name=name) for name in tool_names] assert embedding_config or self._default_embedding_config, f"Embedding config must be provided" assert llm_config or self._default_llm_config, f"LLM config must be provided" + # TODO: This should not happen here, we need to have clear separation between create/add blocks + # TODO: This is insanely hacky and a result of allowing free-floating blocks + # TODO: When we create the block, it gets it's own block ID + blocks = [] + for block in memory.get_blocks(): + blocks.append( + self.create_block( + label=block.label, + value=block.value, + limit=block.limit, + template_name=block.template_name, + is_template=block.is_template, + ) + ) + memory.blocks = blocks + block_ids = block_ids or [] + # create agent - request = CreateAgent( - name=name, - description=description, - metadata_=metadata, - memory_blocks=[], - tools=tool_names, - tool_rules=tool_rules, - system=system, - agent_type=agent_type, - llm_config=llm_config if llm_config else self._default_llm_config, - embedding_config=embedding_config if embedding_config else self._default_embedding_config, - initial_message_sequence=initial_message_sequence, - tags=tags, - ) + create_params = { + "description": description, + "metadata_": metadata, + "memory_blocks": [], + "block_ids": [b.id for b in memory.get_blocks()] + block_ids, + "tool_ids": tool_ids, + "tool_rules": tool_rules, + "system": system, + "agent_type": agent_type, + "llm_config": llm_config if llm_config else self._default_llm_config, + "embedding_config": embedding_config if embedding_config else self._default_embedding_config, + "initial_message_sequence": initial_message_sequence, + "tags": tags, + } + + # Only add name if it's not None + if name is not None: + create_params["name"] = name + + request = CreateAgent(**create_params) # Use model_dump_json() instead of model_dump() # If we use model_dump(), the datetime objects will not be serialized correctly @@ -561,14 +586,6 @@ class RESTClient(AbstractClient): # gather agent state agent_state = AgentState(**response.json()) - # create and link blocks - for block in memory.get_blocks(): - if not self.get_block(block.id): - # note: this does not update existing blocks - # WARNING: this resets the block ID - this method is a hack for backwards compat, should eventually use CreateBlock not Memory - block = self.create_block(label=block.label, value=block.value, limit=block.limit) - self.link_agent_memory_block(agent_id=agent_state.id, block_id=block.id) - # refresh and return agent return self.get_agent(agent_state.id) @@ -602,7 +619,7 @@ class RESTClient(AbstractClient): name: Optional[str] = None, description: Optional[str] = None, system: Optional[str] = None, - tool_names: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, metadata: Optional[Dict] = None, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, @@ -617,7 +634,7 @@ class RESTClient(AbstractClient): name (str): Name of the agent description (str): Description of the agent system (str): System configuration - tool_names (List[str]): List of tools + tool_ids (List[str]): List of tools metadata (Dict): Metadata llm_config (LLMConfig): LLM configuration embedding_config (EmbeddingConfig): Embedding configuration @@ -627,11 +644,10 @@ class RESTClient(AbstractClient): Returns: agent_state (AgentState): State of the updated agent """ - request = UpdateAgentState( - id=agent_id, + request = UpdateAgent( name=name, system=system, - tool_names=tool_names, + tool_ids=tool_ids, tags=tags, description=description, metadata_=metadata, @@ -742,7 +758,7 @@ class RESTClient(AbstractClient): agents = [AgentState(**agent) for agent in response.json()] if len(agents) == 0: return None - agents = [agents[0]] # TODO: @matt monkeypatched + agents = [agents[0]] # TODO: @matt monkeypatched assert len(agents) == 1, f"Multiple agents with the same name: {[(agents.name, agents.id) for agents in agents]}" return agents[0].id @@ -1052,7 +1068,7 @@ class RESTClient(AbstractClient): raise ValueError(f"Failed to update block: {response.text}") return Block(**response.json()) - def get_block(self, block_id: str) -> Block: + def get_block(self, block_id: str) -> Optional[Block]: response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", headers=self.headers) if response.status_code == 404: return None @@ -1607,23 +1623,6 @@ class RESTClient(AbstractClient): raise ValueError(f"Failed to get tool: {response.text}") return Tool(**response.json()) - def get_tool_id(self, name: str) -> Optional[str]: - """ - Get a tool ID by its name. - - Args: - id (str): ID of the tool - - Returns: - tool (Tool): Tool - """ - response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{name}", headers=self.headers) - if response.status_code == 404: - return None - elif response.status_code != 200: - raise ValueError(f"Failed to get tool: {response.text}") - return response.json() - def set_default_llm_config(self, llm_config: LLMConfig): """ Set the default LLM configuration @@ -2006,7 +2005,6 @@ class LocalClient(AbstractClient): A local client for Letta, which corresponds to a single user. Attributes: - auto_save (bool): Whether to automatically save changes. user_id (str): The user ID. debug (bool): Whether to print debug information. interface (QueuingInterface): The interface for the client. @@ -2015,7 +2013,6 @@ class LocalClient(AbstractClient): def __init__( self, - auto_save: bool = False, user_id: Optional[str] = None, org_id: Optional[str] = None, debug: bool = False, @@ -2026,11 +2023,9 @@ class LocalClient(AbstractClient): Initializes a new instance of Client class. Args: - auto_save (bool): Whether to automatically save changes. user_id (str): The user ID. debug (bool): Whether to print debug information. """ - self.auto_save = auto_save # set logging levels letta.utils.DEBUG = debug @@ -2056,14 +2051,14 @@ class LocalClient(AbstractClient): # get default user self.user_id = self.server.user_manager.DEFAULT_USER_ID - self.user = self.server.get_user_or_default(self.user_id) + self.user = self.server.user_manager.get_user_or_default(self.user_id) self.organization = self.server.get_organization_or_default(self.org_id) # agents def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]: self.interface.clear() - return self.server.list_agents(user_id=self.user_id, tags=tags) + return self.server.agent_manager.list_agents(actor=self.user, tags=tags) def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: """ @@ -2097,6 +2092,7 @@ class LocalClient(AbstractClient): llm_config: LLMConfig = None, # memory memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), + block_ids: Optional[List[str]] = None, # TODO: change to this when we are ready to migrate all the tests/examples (matches the REST API) # memory_blocks=[ # {"label": "human", "value": get_human_text(DEFAULT_HUMAN), "limit": 5000}, @@ -2105,7 +2101,7 @@ class LocalClient(AbstractClient): # system system: Optional[str] = None, # tools - tools: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, tool_rules: Optional[List[BaseToolRule]] = None, include_base_tools: Optional[bool] = True, # metadata @@ -2132,55 +2128,53 @@ class LocalClient(AbstractClient): Returns: agent_state (AgentState): State of the created agent """ - - if name and self.agent_exists(agent_name=name): - raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})") - # construct list of tools + tool_ids = tool_ids or [] tool_names = [] - if tools: - tool_names += tools if include_base_tools: tool_names += BASE_TOOLS tool_names += BASE_MEMORY_TOOLS + tool_ids += [self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user).id for name in tool_names] # check if default configs are provided assert embedding_config or self._default_embedding_config, f"Embedding config must be provided" assert llm_config or self._default_llm_config, f"LLM config must be provided" + # TODO: This should not happen here, we need to have clear separation between create/add blocks + for block in memory.get_blocks(): + self.server.block_manager.create_or_update_block(block, actor=self.user) + + # Also get any existing block_ids passed in + block_ids = block_ids or [] + # create agent + # Create the base parameters + create_params = { + "description": description, + "metadata_": metadata, + "memory_blocks": [], + "block_ids": [b.id for b in memory.get_blocks()] + block_ids, + "tool_ids": tool_ids, + "tool_rules": tool_rules, + "system": system, + "agent_type": agent_type, + "llm_config": llm_config if llm_config else self._default_llm_config, + "embedding_config": embedding_config if embedding_config else self._default_embedding_config, + "initial_message_sequence": initial_message_sequence, + "tags": tags, + } + + # Only add name if it's not None + if name is not None: + create_params["name"] = name + agent_state = self.server.create_agent( - CreateAgent( - name=name, - description=description, - metadata_=metadata, - # memory=memory, - memory_blocks=[], - # memory_blocks = memory.get_blocks(), - # memory_tools=memory_tools, - tools=tool_names, - tool_rules=tool_rules, - system=system, - agent_type=agent_type, - llm_config=llm_config if llm_config else self._default_llm_config, - embedding_config=embedding_config if embedding_config else self._default_embedding_config, - initial_message_sequence=initial_message_sequence, - tags=tags, - ), + CreateAgent(**create_params), actor=self.user, ) - # TODO: remove when we fully migrate to block creation CreateAgent model - # Link additional blocks to the agent (block ids created on the client) - # This needs to happen since the create agent does not allow passing in blocks which have already been persisted and have an ID - # So we create the agent and then link the blocks afterwards - user = self.server.get_user_or_default(self.user_id) - for block in memory.get_blocks(): - self.server.block_manager.create_or_update_block(block, actor=user) - self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_state.id, block_id=block.id) - # TODO: get full agent state - return self.server.get_agent(agent_state.id) + return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user) def update_message( self, @@ -2202,6 +2196,7 @@ class LocalClient(AbstractClient): tool_calls=tool_calls, tool_call_id=tool_call_id, ), + actor=self.user, ) return message @@ -2211,7 +2206,7 @@ class LocalClient(AbstractClient): name: Optional[str] = None, description: Optional[str] = None, system: Optional[str] = None, - tools: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict] = None, llm_config: Optional[LLMConfig] = None, @@ -2239,11 +2234,11 @@ class LocalClient(AbstractClient): # TODO: add the abilitty to reset linked block_ids self.interface.clear() agent_state = self.server.update_agent( - UpdateAgentState( - id=agent_id, + agent_id, + UpdateAgent( name=name, system=system, - tool_names=tools, + tool_ids=tool_ids, tags=tags, description=description, metadata_=metadata, @@ -2315,7 +2310,7 @@ class LocalClient(AbstractClient): Args: agent_id (str): ID of the agent to delete """ - self.server.delete_agent(user_id=self.user_id, agent_id=agent_id) + self.server.agent_manager.delete_agent(agent_id=agent_id, actor=self.user) def get_agent_by_name(self, agent_name: str) -> AgentState: """ @@ -2328,7 +2323,7 @@ class LocalClient(AbstractClient): agent_state (AgentState): State of the agent """ self.interface.clear() - return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None) + return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user) def get_agent(self, agent_id: str) -> AgentState: """ @@ -2340,9 +2335,8 @@ class LocalClient(AbstractClient): Returns: agent_state (AgentState): State representation of the agent """ - # TODO: include agent_name self.interface.clear() - return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id) + return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) def get_agent_id(self, agent_name: str) -> Optional[str]: """ @@ -2357,7 +2351,12 @@ class LocalClient(AbstractClient): self.interface.clear() assert agent_name, f"Agent name must be provided" - return self.server.get_agent_id(name=agent_name, user_id=self.user_id) + + # TODO: Refactor this futher to not have downstream users expect Optionals - this should just error + try: + return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user).id + except NoResultFound: + return None # memory def get_in_context_memory(self, agent_id: str) -> Memory: @@ -2370,7 +2369,7 @@ class LocalClient(AbstractClient): Returns: memory (Memory): In-context memory of the agent """ - memory = self.server.get_agent_memory(agent_id=agent_id) + memory = self.server.get_agent_memory(agent_id=agent_id, actor=self.user) return memory def get_core_memory(self, agent_id: str) -> Memory: @@ -2388,7 +2387,7 @@ class LocalClient(AbstractClient): """ # TODO: implement this (not sure what it should look like) - memory = self.server.update_agent_core_memory(user_id=self.user_id, agent_id=agent_id, label=section, value=value) + memory = self.server.update_agent_core_memory(agent_id=agent_id, label=section, value=value, actor=self.user) return memory def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary: @@ -2402,7 +2401,7 @@ class LocalClient(AbstractClient): summary (ArchivalMemorySummary): Summary of the archival memory """ - return self.server.get_archival_memory_summary(agent_id=agent_id) + return self.server.get_archival_memory_summary(agent_id=agent_id, actor=self.user) def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary: """ @@ -2414,7 +2413,7 @@ class LocalClient(AbstractClient): Returns: summary (RecallMemorySummary): Summary of the recall memory """ - return self.server.get_recall_memory_summary(agent_id=agent_id) + return self.server.get_recall_memory_summary(agent_id=agent_id, actor=self.user) def get_in_context_messages(self, agent_id: str) -> List[Message]: """ @@ -2426,7 +2425,7 @@ class LocalClient(AbstractClient): Returns: messages (List[Message]): List of in-context messages """ - return self.server.get_in_context_messages(agent_id=agent_id) + return self.server.get_in_context_messages(agent_id=agent_id, actor=self.user) # agent interactions @@ -2446,11 +2445,7 @@ class LocalClient(AbstractClient): response (LettaResponse): Response from the agent """ self.interface.clear() - usage = self.server.send_messages(user_id=self.user_id, agent_id=agent_id, messages=messages) - - # auto-save - if self.auto_save: - self.save() + usage = self.server.send_messages(actor=self.user, agent_id=agent_id, messages=messages) # format messages return LettaResponse(messages=messages, usage=usage) @@ -2490,15 +2485,11 @@ class LocalClient(AbstractClient): self.interface.clear() usage = self.server.send_messages( - user_id=self.user_id, + actor=self.user, agent_id=agent_id, messages=[MessageCreate(role=MessageRole(role), text=message, name=name)], ) - # auto-save - if self.auto_save: - self.save() - ## TODO: need to make sure date/timestamp is propely passed ## TODO: update self.interface.to_list() to return actual Message objects ## here, the message objects will have faulty created_by timestamps @@ -2547,16 +2538,9 @@ class LocalClient(AbstractClient): self.interface.clear() usage = self.server.run_command(user_id=self.user_id, agent_id=agent_id, command=command) - # auto-save - if self.auto_save: - self.save() - # NOTE: messages/usage may be empty, depending on the command return LettaResponse(messages=self.interface.to_list(), usage=usage) - def save(self): - self.server.save_agents() - # archival memory # humans / personas @@ -3036,7 +3020,7 @@ class LocalClient(AbstractClient): Returns: sources (List[Source]): List of sources """ - return self.server.list_attached_sources(agent_id=agent_id) + return self.server.agent_manager.list_attached_sources(agent_id=agent_id, actor=self.user) def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]: """ @@ -3080,7 +3064,7 @@ class LocalClient(AbstractClient): Returns: passages (List[Passage]): List of inserted passages """ - return self.server.insert_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_contents=memory) + return self.server.insert_archival_memory(agent_id=agent_id, memory_contents=memory, actor=self.user) def delete_archival_memory(self, agent_id: str, memory_id: str): """ @@ -3090,7 +3074,7 @@ class LocalClient(AbstractClient): agent_id (str): ID of the agent memory_id (str): ID of the memory """ - self.server.delete_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_id=memory_id) + self.server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=self.user) def get_archival_memory( self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000 @@ -3349,8 +3333,8 @@ class LocalClient(AbstractClient): block_req = Block(**create_block.model_dump()) block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req) # Link the block to the agent - updated_memory = self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block.id) - return updated_memory + agent = self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=self.user) + return agent.memory def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory: """ @@ -3363,7 +3347,7 @@ class LocalClient(AbstractClient): Returns: memory (Memory): The updated memory """ - return self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block_id) + return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user) def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory: """ @@ -3376,7 +3360,7 @@ class LocalClient(AbstractClient): Returns: memory (Memory): The updated memory """ - return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label) + return self.server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=self.user) def get_agent_memory_blocks(self, agent_id: str) -> List[Block]: """ @@ -3388,8 +3372,8 @@ class LocalClient(AbstractClient): Returns: blocks (List[Block]): The blocks in the agent's core memory """ - block_ids = self.server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id) - return [self.server.block_manager.get_block_by_id(block_id, actor=self.user) for block_id in block_ids] + agent = self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) + return agent.memory.blocks def get_agent_memory_block(self, agent_id: str, label: str) -> Block: """ @@ -3402,8 +3386,7 @@ class LocalClient(AbstractClient): Returns: block (Block): The block corresponding to the label """ - block_id = self.server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=label) - return self.server.block_manager.get_block_by_id(block_id, actor=self.user) + return self.server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=self.user) def update_agent_memory_block( self, diff --git a/letta/config.py b/letta/config.py index 70b6bf38..ed9e8668 100644 --- a/letta/config.py +++ b/letta/config.py @@ -1,12 +1,9 @@ import configparser -import inspect -import json import os from dataclasses import dataclass from typing import Optional import letta -import letta.utils as utils from letta.constants import ( CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT, @@ -16,7 +13,6 @@ from letta.constants import ( LETTA_DIR, ) from letta.log import get_logger -from letta.schemas.agent import PersistedAgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig @@ -312,160 +308,3 @@ class LettaConfig: for folder in folders: if not os.path.exists(os.path.join(LETTA_DIR, folder)): os.makedirs(os.path.join(LETTA_DIR, folder)) - - -@dataclass -class AgentConfig: - """ - - NOTE: this is a deprecated class, use AgentState instead. This class is only used for backcompatibility. - Configuration for a specific instance of an agent - """ - - def __init__( - self, - persona, - human, - # model info - model=None, - model_endpoint_type=None, - model_endpoint=None, - model_wrapper=None, - context_window=None, - # embedding info - embedding_endpoint_type=None, - embedding_endpoint=None, - embedding_model=None, - embedding_dim=None, - embedding_chunk_size=None, - # other - preset=None, - data_sources=None, - # agent info - agent_config_path=None, - name=None, - create_time=None, - letta_version=None, - # functions - functions=None, # schema definitions ONLY (linked at runtime) - ): - - assert name, f"Agent name must be provided" - self.name = name - - config = LettaConfig.load() # get default values - self.persona = config.persona if persona is None else persona - self.human = config.human if human is None else human - self.preset = config.preset if preset is None else preset - self.context_window = config.default_llm_config.context_window if context_window is None else context_window - self.model = config.default_llm_config.model if model is None else model - self.model_endpoint_type = config.default_llm_config.model_endpoint_type if model_endpoint_type is None else model_endpoint_type - self.model_endpoint = config.default_llm_config.model_endpoint if model_endpoint is None else model_endpoint - self.model_wrapper = config.default_llm_config.model_wrapper if model_wrapper is None else model_wrapper - self.llm_config = LLMConfig( - model=self.model, - model_endpoint_type=self.model_endpoint_type, - model_endpoint=self.model_endpoint, - model_wrapper=self.model_wrapper, - context_window=self.context_window, - ) - self.embedding_endpoint_type = ( - config.default_embedding_config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type - ) - self.embedding_endpoint = config.default_embedding_config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint - self.embedding_model = config.default_embedding_config.embedding_model if embedding_model is None else embedding_model - self.embedding_dim = config.default_embedding_config.embedding_dim if embedding_dim is None else embedding_dim - self.embedding_chunk_size = ( - config.default_embedding_config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size - ) - self.embedding_config = EmbeddingConfig( - embedding_endpoint_type=self.embedding_endpoint_type, - embedding_endpoint=self.embedding_endpoint, - embedding_model=self.embedding_model, - embedding_dim=self.embedding_dim, - embedding_chunk_size=self.embedding_chunk_size, - ) - - # agent metadata - self.data_sources = data_sources if data_sources is not None else [] - self.create_time = create_time if create_time is not None else utils.get_local_time() - if letta_version is None: - import letta - - self.letta_version = letta.__version__ - else: - self.letta_version = letta_version - - # functions - self.functions = functions - - # save agent config - self.agent_config_path = ( - os.path.join(LETTA_DIR, "agents", self.name, "config.json") if agent_config_path is None else agent_config_path - ) - - def attach_data_source(self, data_source: str): - # TODO: add warning that only once source can be attached - # i.e. previous source will be overriden - self.data_sources.append(data_source) - self.save() - - def save_dir(self): - return os.path.join(LETTA_DIR, "agents", self.name) - - def save_state_dir(self): - # directory to save agent state - return os.path.join(LETTA_DIR, "agents", self.name, "agent_state") - - def save_persistence_manager_dir(self): - # directory to save persistent manager state - return os.path.join(LETTA_DIR, "agents", self.name, "persistence_manager") - - def save_agent_index_dir(self): - # save llama index inside of persistent manager directory - return os.path.join(self.save_persistence_manager_dir(), "index") - - def save(self): - # save state of persistence manager - os.makedirs(os.path.join(LETTA_DIR, "agents", self.name), exist_ok=True) - # save version - self.letta_version = letta.__version__ - with open(self.agent_config_path, "w", encoding="utf-8") as f: - json.dump(vars(self), f, indent=4) - - def to_agent_state(self): - return PersistedAgentState( - name=self.name, - preset=self.preset, - persona=self.persona, - human=self.human, - llm_config=self.llm_config, - embedding_config=self.embedding_config, - create_time=self.create_time, - ) - - @staticmethod - def exists(name: str): - """Check if agent config exists""" - agent_config_path = os.path.join(LETTA_DIR, "agents", name) - return os.path.exists(agent_config_path) - - @classmethod - def load(cls, name: str): - """Load agent config from JSON file""" - agent_config_path = os.path.join(LETTA_DIR, "agents", name, "config.json") - assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}" - with open(agent_config_path, "r", encoding="utf-8") as f: - agent_config = json.load(f) - # allow compatibility accross versions - try: - class_args = inspect.getargspec(cls.__init__).args - except AttributeError: - # https://github.com/pytorch/pytorch/issues/15344 - class_args = inspect.getfullargspec(cls.__init__).args - agent_fields = list(agent_config.keys()) - for key in agent_fields: - if key not in class_args: - utils.printd(f"Removing missing argument {key} from agent config") - del agent_config[key] - return cls(**agent_config) diff --git a/letta/main.py b/letta/main.py index 6a394fcf..bb5d1431 100644 --- a/letta/main.py +++ b/letta/main.py @@ -130,11 +130,11 @@ def run_agent_loop( # updated agent save functions if user_input.lower() == "/exit": # letta_agent.save() - agent.save_agent(letta_agent, ms) + agent.save_agent(letta_agent) break elif user_input.lower() == "/save" or user_input.lower() == "/savechat": # letta_agent.save() - agent.save_agent(letta_agent, ms) + agent.save_agent(letta_agent) continue elif user_input.lower() == "/attach": # TODO: check if agent already has it @@ -394,7 +394,7 @@ def run_agent_loop( token_warning = step_response.in_context_memory_warning step_response.usage - agent.save_agent(letta_agent, ms) + agent.save_agent(letta_agent) skip_next_user_input = False if token_warning: user_message = system.get_token_limit_warning() diff --git a/letta/memory.py b/letta/memory.py index 081c59d5..10799094 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -1,23 +1,13 @@ -import datetime -from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC -from letta.embeddings import embedding_model, parse_and_chunk_text, query_embedding from letta.llm_api.llm_api_tools import create from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole from letta.schemas.memory import Memory from letta.schemas.message import Message -from letta.schemas.passage import Passage -from letta.utils import ( - count_tokens, - extract_date_from_timestamp, - get_local_time, - printd, - validate_date_format, -) +from letta.utils import count_tokens, printd def get_memory_functions(cls: Memory) -> Dict[str, Callable]: @@ -67,7 +57,6 @@ def summarize_messages( + message_sequence_to_summarize[cutoff:] ) - agent_state.user_id dummy_agent_id = agent_state.id message_sequence = [] message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt)) @@ -79,7 +68,7 @@ def summarize_messages( llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False response = create( llm_config=llm_config_no_inner_thoughts, - user_id=agent_state.user_id, + user_id=agent_state.created_by_id, messages=message_sequence, stream=False, ) diff --git a/letta/metadata.py b/letta/metadata.py index 017e546e..0ecd696b 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -2,23 +2,18 @@ import os import secrets -from typing import List, Optional, Union +from typing import List, Optional -from sqlalchemy import JSON, Column, DateTime, Index, String, TypeDecorator -from sqlalchemy.sql import func +from sqlalchemy import JSON, Column, Index, String, TypeDecorator from letta.config import LettaConfig from letta.orm.base import Base -from letta.schemas.agent import PersistedAgentState from letta.schemas.api_key import APIKey from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import ToolRuleType from letta.schemas.llm_config import LLMConfig -from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from letta.schemas.user import User -from letta.services.per_agent_lock_manager import PerAgentLockManager from letta.settings import settings -from letta.utils import enforce_types, printd +from letta.utils import enforce_types class LLMConfigColumn(TypeDecorator): @@ -65,18 +60,6 @@ class EmbeddingConfigColumn(TypeDecorator): return value -# TODO: eventually store providers? -# class Provider(Base): -# __tablename__ = "providers" -# __table_args__ = {"extend_existing": True} -# -# id = Column(String, primary_key=True) -# name = Column(String, nullable=False) -# created_at = Column(DateTime(timezone=True)) -# api_key = Column(String, nullable=False) -# base_url = Column(String, nullable=False) - - class APIKeyModel(Base): """Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens).""" @@ -113,115 +96,6 @@ def generate_api_key(prefix="sk-", length=51) -> str: return new_key -class ToolRulesColumn(TypeDecorator): - """Custom type for storing a list of ToolRules as JSON""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - """Convert a list of ToolRules to JSON-serializable format.""" - if value: - data = [rule.model_dump() for rule in value] - for d in data: - d["type"] = d["type"].value - - for d in data: - assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field" - return data - return value - - def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]: - """Convert JSON back to a list of ToolRules.""" - if value: - return [self.deserialize_tool_rule(rule_data) for rule_data in value] - return value - - @staticmethod - def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]: - """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'.""" - rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var - if rule_type == ToolRuleType.run_first: - return InitToolRule(**data) - elif rule_type == ToolRuleType.exit_loop: - return TerminalToolRule(**data) - elif rule_type == ToolRuleType.constrain_child_tools: - rule = ChildToolRule(**data) - return rule - else: - raise ValueError(f"Unknown tool rule type: {rule_type}") - - -class AgentModel(Base): - """Defines data model for storing Passages (consisting of text, embedding)""" - - __tablename__ = "agents" - __table_args__ = {"extend_existing": True} - - id = Column(String, primary_key=True) - user_id = Column(String, nullable=False) - name = Column(String, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - description = Column(String) - - # state (context compilation) - message_ids = Column(JSON) - system = Column(String) - - # configs - agent_type = Column(String) - llm_config = Column(LLMConfigColumn) - embedding_config = Column(EmbeddingConfigColumn) - - # state - metadata_ = Column(JSON) - - # tools - tool_names = Column(JSON) - tool_rules = Column(ToolRulesColumn) - - Index(__tablename__ + "_idx_user", user_id), - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> PersistedAgentState: - agent_state = PersistedAgentState( - id=self.id, - user_id=self.user_id, - name=self.name, - created_at=self.created_at, - description=self.description, - message_ids=self.message_ids, - system=self.system, - tool_names=self.tool_names, - tool_rules=self.tool_rules, - agent_type=self.agent_type, - llm_config=self.llm_config, - embedding_config=self.embedding_config, - metadata_=self.metadata_, - ) - return agent_state - - -class AgentSourceMappingModel(Base): - """Stores mapping between agent -> source""" - - __tablename__ = "agent_source_mapping" - - id = Column(String, primary_key=True) - user_id = Column(String, nullable=False) - agent_id = Column(String, nullable=False) - source_id = Column(String, nullable=False) - Index(__tablename__ + "_idx_user", user_id, agent_id, source_id), - - def __repr__(self) -> str: - return f"" - - class MetadataStore: uri: Optional[str] = None @@ -281,127 +155,3 @@ class MetadataStore: results = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id).all() tokens = [r.to_record() for r in results] return tokens - - @enforce_types - def create_agent(self, agent: PersistedAgentState): - # insert into agent table - # make sure agent.name does not already exist for user user_id - with self.session_maker() as session: - if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0: - raise ValueError(f"Agent with name {agent.name} already exists") - fields = vars(agent) - # fields["memory"] = agent.memory.to_dict() - # if "_internal_memory" in fields: - # del fields["_internal_memory"] - # else: - # warnings.warn(f"Agent {agent.id} has no _internal_memory field") - if "tags" in fields: - del fields["tags"] - # else: - # warnings.warn(f"Agent {agent.id} has no tags field") - session.add(AgentModel(**fields)) - session.commit() - - @enforce_types - def update_agent(self, agent: PersistedAgentState): - with self.session_maker() as session: - fields = vars(agent) - # if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever - # fields["memory"] = agent.memory.to_dict() - # if "_internal_memory" in fields: - # del fields["_internal_memory"] - # else: - # warnings.warn(f"Agent {agent.id} has no _internal_memory field") - if "tags" in fields: - del fields["tags"] - # else: - # warnings.warn(f"Agent {agent.id} has no tags field") - session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields) - session.commit() - - @enforce_types - def delete_agent(self, agent_id: str, per_agent_lock_manager: PerAgentLockManager): - # TODO: Remove this once Agent is on the ORM - # TODO: To prevent unbounded growth - per_agent_lock_manager.clear_lock(agent_id) - - with self.session_maker() as session: - - # delete agents - session.query(AgentModel).filter(AgentModel.id == agent_id).delete() - - # delete mappings - session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).delete() - - session.commit() - - @enforce_types - def list_agents(self, user_id: str) -> List[PersistedAgentState]: - with self.session_maker() as session: - results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all() - return [r.to_record() for r in results] - - @enforce_types - def get_agent( - self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None - ) -> Optional[PersistedAgentState]: - with self.session_maker() as session: - if agent_id: - results = session.query(AgentModel).filter(AgentModel.id == agent_id).all() - else: - assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name" - results = session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all() - - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result - return results[0].to_record() - - # agent source metadata - @enforce_types - def attach_source(self, user_id: str, agent_id: str, source_id: str): - with self.session_maker() as session: - # TODO: remove this (is a hack) - mapping_id = f"{user_id}-{agent_id}-{source_id}" - existing = session.query(AgentSourceMappingModel).filter( - AgentSourceMappingModel.id == mapping_id - ).first() - - if existing is None: - # Only create if it doesn't exist - session.add(AgentSourceMappingModel( - id=mapping_id, - user_id=user_id, - agent_id=agent_id, - source_id=source_id - )) - session.commit() - - @enforce_types - def list_attached_source_ids(self, agent_id: str) -> List[str]: - with self.session_maker() as session: - results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all() - return [r.source_id for r in results] - - @enforce_types - def list_attached_agents(self, source_id: str) -> List[str]: - with self.session_maker() as session: - results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all() - - agent_ids = [] - # make sure agent exists - for r in results: - agent = self.get_agent(agent_id=r.agent_id) - if agent: - agent_ids.append(r.agent_id) - else: - printd(f"Warning: agent {r.agent_id} does not exist but exists in mapping database. This should never happen.") - return agent_ids - - @enforce_types - def detach_source(self, agent_id: str, source_id: str): - with self.session_maker() as session: - session.query(AgentSourceMappingModel).filter( - AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id - ).delete() - session.commit() diff --git a/letta/o1_agent.py b/letta/o1_agent.py index 837005cf..eb882bfa 100644 --- a/letta/o1_agent.py +++ b/letta/o1_agent.py @@ -85,6 +85,6 @@ class O1Agent(Agent): if step_response.messages[-1].name == "send_final_message": break if ms: - save_agent(self, ms) + save_agent(self) return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) diff --git a/letta/offline_memory_agent.py b/letta/offline_memory_agent.py index 85cbb082..1e71af6c 100644 --- a/letta/offline_memory_agent.py +++ b/letta/offline_memory_agent.py @@ -130,7 +130,7 @@ class OfflineMemoryAgent(Agent): # extras first_message_verify_mono: bool = False, max_memory_rethinks: int = 10, - initial_message_sequence: Optional[List[Message]] = None, + initial_message_sequence: Optional[List[Message]] = None, ): super().__init__(interface, agent_state, user, initial_message_sequence=initial_message_sequence) self.first_message_verify_mono = first_message_verify_mono @@ -173,6 +173,6 @@ class OfflineMemoryAgent(Agent): self.interface.step_complete() if ms: - save_agent(self, ms) + save_agent(self) return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index b7f7bb96..ed8f2460 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -1,3 +1,4 @@ +from letta.orm.agent import Agent from letta.orm.agents_tags import AgentsTags from letta.orm.base import Base from letta.orm.block import Block @@ -9,6 +10,7 @@ from letta.orm.organization import Organization from letta.orm.passage import Passage from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source +from letta.orm.sources_agents import SourcesAgents from letta.orm.tool import Tool from letta.orm.tools_agents import ToolsAgents from letta.orm.user import User diff --git a/letta/orm/agent.py b/letta/orm/agent.py new file mode 100644 index 00000000..99a7e8bd --- /dev/null +++ b/letta/orm/agent.py @@ -0,0 +1,196 @@ +import uuid +from typing import TYPE_CHECKING, List, Optional, Union + +from sqlalchemy import JSON, String, TypeDecorator, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.block import Block +from letta.orm.message import Message +from letta.orm.mixins import OrganizationMixin +from letta.orm.organization import Organization +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.agent import AgentState as PydanticAgentState +from letta.schemas.agent import AgentType +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ToolRuleType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.memory import Memory +from letta.schemas.tool_rule import ( + ChildToolRule, + InitToolRule, + TerminalToolRule, + ToolRule, +) + +if TYPE_CHECKING: + from letta.orm.agents_tags import AgentsTags + from letta.orm.organization import Organization + from letta.orm.source import Source + from letta.orm.tool import Tool + + +class LLMConfigColumn(TypeDecorator): + """Custom type for storing LLMConfig as JSON""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value: + # return vars(value) + if isinstance(value, LLMConfig): + return value.model_dump() + return value + + def process_result_value(self, value, dialect): + if value: + return LLMConfig(**value) + return value + + +class EmbeddingConfigColumn(TypeDecorator): + """Custom type for storing EmbeddingConfig as JSON""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value: + # return vars(value) + if isinstance(value, EmbeddingConfig): + return value.model_dump() + return value + + def process_result_value(self, value, dialect): + if value: + return EmbeddingConfig(**value) + return value + + +class ToolRulesColumn(TypeDecorator): + """Custom type for storing a list of ToolRules as JSON""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + """Convert a list of ToolRules to JSON-serializable format.""" + if value: + data = [rule.model_dump() for rule in value] + for d in data: + d["type"] = d["type"].value + + for d in data: + assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field" + return data + return value + + def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]: + """Convert JSON back to a list of ToolRules.""" + if value: + return [self.deserialize_tool_rule(rule_data) for rule_data in value] + return value + + @staticmethod + def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]: + """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'.""" + rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var + if rule_type == ToolRuleType.run_first: + return InitToolRule(**data) + elif rule_type == ToolRuleType.exit_loop: + return TerminalToolRule(**data) + elif rule_type == ToolRuleType.constrain_child_tools: + rule = ChildToolRule(**data) + return rule + else: + raise ValueError(f"Unknown tool rule type: {rule_type}") + + +class Agent(SqlalchemyBase, OrganizationMixin): + __tablename__ = "agents" + __pydantic_model__ = PydanticAgentState + __table_args__ = (UniqueConstraint("organization_id", "name", name="unique_org_agent_name"),) + + # agent generates its own id + # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase + # TODO: Move this in this PR? at the very end? + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-{uuid.uuid4()}") + + # Descriptor fields + agent_type: Mapped[Optional[AgentType]] = mapped_column(String, nullable=True, doc="The type of Agent") + name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a human-readable identifier for an agent, non-unique.") + description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The description of the agent.") + + # System prompt + system: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The system prompt used by the agent.") + + # In context memory + # TODO: This should be a separate mapping table + # This is dangerously flexible with the JSON type + message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.") + + # Metadata and configs + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.") + llm_config: Mapped[Optional[LLMConfig]] = mapped_column( + LLMConfigColumn, nullable=True, doc="the LLM backend configuration object for this agent." + ) + embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column( + EmbeddingConfigColumn, doc="the embedding configuration object for this agent." + ) + + # Tool rules + tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="agents") + tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_agents", lazy="selectin", passive_deletes=True) + sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_agents", lazy="selectin") + core_memory: Mapped[List["Block"]] = relationship("Block", secondary="blocks_agents", lazy="selectin") + messages: Mapped[List["Message"]] = relationship( + "Message", + back_populates="agent", + lazy="selectin", + cascade="all, delete-orphan", # Ensure messages are deleted when the agent is deleted + passive_deletes=True, + ) + tags: Mapped[List["AgentsTags"]] = relationship( + "AgentsTags", + back_populates="agent", + cascade="all, delete-orphan", + lazy="selectin", + doc="Tags associated with the agent.", + ) + # passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="agent", lazy="selectin") + + def to_pydantic(self) -> PydanticAgentState: + """converts to the basic pydantic model counterpart""" + state = { + "id": self.id, + "name": self.name, + "description": self.description, + "message_ids": self.message_ids, + "tools": self.tools, + "sources": self.sources, + "tags": [t.tag for t in self.tags], + "tool_rules": self.tool_rules, + "system": self.system, + "agent_type": self.agent_type, + "llm_config": self.llm_config, + "embedding_config": self.embedding_config, + "metadata_": self.metadata_, + "memory": Memory(blocks=[b.to_pydantic() for b in self.core_memory]), + "created_by_id": self.created_by_id, + "last_updated_by_id": self.last_updated_by_id, + "created_at": self.created_at, + "updated_at": self.updated_at, + } + return self.__pydantic_model__(**state) diff --git a/letta/orm/agents_tags.py b/letta/orm/agents_tags.py index 1910f528..76ff9011 100644 --- a/letta/orm/agents_tags.py +++ b/letta/orm/agents_tags.py @@ -1,28 +1,20 @@ -from typing import TYPE_CHECKING - from sqlalchemy import ForeignKey, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.orm.mixins import OrganizationMixin -from letta.orm.sqlalchemy_base import SqlalchemyBase -from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags - -if TYPE_CHECKING: - from letta.orm.organization import Organization +from letta.orm.base import Base -class AgentsTags(SqlalchemyBase, OrganizationMixin): - """Associates tags with agents, allowing agents to have multiple tags and supporting tag-based filtering.""" - +class AgentsTags(Base): __tablename__ = "agents_tags" - __pydantic_model__ = PydanticAgentsTags __table_args__ = (UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),) - # The agent associated with this tag - agent_id = mapped_column(String, ForeignKey("agents.id"), primary_key=True) + # # agent generates its own id + # # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase + # # TODO: Move this in this PR? at the very end? + # id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agents_tags-{uuid.uuid4()}") - # The name of the tag - tag: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the tag associated with the agent.") + agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) + tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the agent.", primary_key=True) - # relationships - organization: Mapped["Organization"] = relationship("Organization", back_populates="agents_tags") + # Relationships + agent: Mapped["Agent"] = relationship("Agent", back_populates="tags") diff --git a/letta/orm/block.py b/letta/orm/block.py index 84bbdb7e..99cfa29b 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -1,16 +1,17 @@ from typing import TYPE_CHECKING, Optional, Type -from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event +from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT +from letta.orm.blocks_agents import BlocksAgents from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import Human, Persona if TYPE_CHECKING: - from letta.orm import BlocksAgents, Organization + from letta.orm import Organization class Block(OrganizationMixin, SqlalchemyBase): @@ -35,7 +36,6 @@ class Block(OrganizationMixin, SqlalchemyBase): # relationships organization: Mapped[Optional["Organization"]] = relationship("Organization") - blocks_agents: Mapped[list["BlocksAgents"]] = relationship("BlocksAgents", back_populates="block", cascade="all, delete") def to_pydantic(self) -> Type: match self.label: @@ -46,3 +46,28 @@ class Block(OrganizationMixin, SqlalchemyBase): case _: Schema = PydanticBlock return Schema.model_validate(self) + + +@event.listens_for(Block, "after_update") # Changed from 'before_update' +def block_before_update(mapper, connection, target): + """Handle updating BlocksAgents when a block's label changes.""" + label_history = attributes.get_history(target, "label") + if not label_history.has_changes(): + return + + blocks_agents = BlocksAgents.__table__ + connection.execute( + blocks_agents.update() + .where(blocks_agents.c.block_id == target.id, blocks_agents.c.block_label == label_history.deleted[0]) + .values(block_label=label_history.added[0]) + ) + + +@event.listens_for(Block, "before_insert") +@event.listens_for(Block, "before_update") +def validate_value_length(mapper, connection, target): + """Ensure the value length does not exceed the limit.""" + if target.value and len(target.value) > target.limit: + raise ValueError( + f"Value length ({len(target.value)}) exceeds the limit ({target.limit}) for block with label '{target.label}' and id '{target.id}'." + ) diff --git a/letta/orm/blocks_agents.py b/letta/orm/blocks_agents.py index a3449646..4774783b 100644 --- a/letta/orm/blocks_agents.py +++ b/letta/orm/blocks_agents.py @@ -1,15 +1,13 @@ from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column -from letta.orm.sqlalchemy_base import SqlalchemyBase -from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents +from letta.orm.base import Base -class BlocksAgents(SqlalchemyBase): +class BlocksAgents(Base): """Agents must have one or many blocks to make up their core memory.""" __tablename__ = "blocks_agents" - __pydantic_model__ = PydanticBlocksAgents __table_args__ = ( UniqueConstraint( "agent_id", @@ -17,16 +15,12 @@ class BlocksAgents(SqlalchemyBase): name="unique_label_per_agent", ), ForeignKeyConstraint( - ["block_id", "block_label"], - ["block.id", "block.label"], - name="fk_block_id_label", + ["block_id", "block_label"], ["block.id", "block.label"], name="fk_block_id_label", deferrable=True, initially="DEFERRED" ), + UniqueConstraint("agent_id", "block_id", name="unique_agent_block"), ) # unique agent + block label agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) block_id: Mapped[str] = mapped_column(String, primary_key=True) block_label: Mapped[str] = mapped_column(String, primary_key=True) - - # relationships - block: Mapped["Block"] = relationship("Block", back_populates="blocks_agents") diff --git a/letta/orm/message.py b/letta/orm/message.py index 8de6f1f5..77ac075a 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -59,6 +59,5 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call") # Relationships - # TODO: Add in after Agent ORM is created - # agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin") + agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin") organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin") diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 8dc56e16..bed2b00f 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization if TYPE_CHECKING: + from letta.orm.agent import Agent from letta.orm.file import FileMetadata from letta.orm.tool import Tool from letta.orm.user import User @@ -25,7 +26,6 @@ class Organization(SqlalchemyBase): tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan") sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") - agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan") files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan") sandbox_configs: Mapped[List["SandboxConfig"]] = relationship( "SandboxConfig", back_populates="organization", cascade="all, delete-orphan" @@ -36,10 +36,5 @@ class Organization(SqlalchemyBase): # relationships messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan") + agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan") - - # TODO: Map these relationships later when we actually make these models - # below is just a suggestion - # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") - # tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") - # documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/orm/passage.py b/letta/orm/passage.py index bfa3e153..b91eb434 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -1,19 +1,18 @@ +import base64 from datetime import datetime -from typing import Optional, TYPE_CHECKING -from sqlalchemy import Column, String, DateTime, JSON, ForeignKey -from sqlalchemy.orm import Mapped, mapped_column, relationship -from sqlalchemy.types import TypeDecorator, BINARY +from typing import TYPE_CHECKING, Optional import numpy as np -import base64 - -from letta.orm.source import EmbeddingConfigColumn -from letta.orm.sqlalchemy_base import SqlalchemyBase -from letta.orm.mixins import FileMixin, OrganizationMixin -from letta.schemas.passage import Passage as PydanticPassage +from sqlalchemy import JSON, Column, DateTime, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.types import BINARY, TypeDecorator from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM +from letta.orm.mixins import FileMixin, OrganizationMixin +from letta.orm.source import EmbeddingConfigColumn +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.passage import Passage as PydanticPassage from letta.settings import settings config = LettaConfig() @@ -21,8 +20,10 @@ config = LettaConfig() if TYPE_CHECKING: from letta.orm.organization import Organization + class CommonVector(TypeDecorator): """Common type for representing vectors in SQLite""" + impl = BINARY cache_ok = True @@ -43,10 +44,12 @@ class CommonVector(TypeDecorator): value = base64.b64decode(value) return np.frombuffer(value, dtype=np.float32) -# TODO: After migration to Passage, will need to manually delete passages where files + +# TODO: After migration to Passage, will need to manually delete passages where files # are deleted on web class Passage(SqlalchemyBase, OrganizationMixin, FileMixin): """Defines data model for storing Passages""" + __tablename__ = "passages" __table_args__ = {"extend_existing": True} __pydantic_model__ = PydanticPassage @@ -59,6 +62,7 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin): created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow) if settings.letta_pg_uri_no_default: from pgvector.sqlalchemy import Vector + embedding = mapped_column(Vector(MAX_EMBEDDING_DIM)) else: embedding = Column(CommonVector) diff --git a/letta/orm/source.py b/letta/orm/source.py index 4b2262f0..e849cddb 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -48,4 +48,4 @@ class Source(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") files: Mapped[List["Source"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan") - # agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") + agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") diff --git a/letta/orm/sources_agents.py b/letta/orm/sources_agents.py new file mode 100644 index 00000000..cf502e71 --- /dev/null +++ b/letta/orm/sources_agents.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.base import Base + + +class SourcesAgents(Base): + """Agents can have zero to many sources""" + + __tablename__ = "sources_agents" + + agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) + source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id"), primary_key=True) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 74d3f3be..d13e85b1 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -1,7 +1,6 @@ from datetime import datetime from enum import Enum -from typing import TYPE_CHECKING, List, Literal, Optional, Type -import sqlite3 +from typing import TYPE_CHECKING, List, Literal, Optional from sqlalchemy import String, desc, func, or_, select from sqlalchemy.exc import DBAPIError @@ -9,12 +8,12 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from letta.log import get_logger from letta.orm.base import Base, CommonSqlalchemyMetaMixins -from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance from letta.orm.errors import ( ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError, ) +from letta.orm.sqlite_functions import adapt_array if TYPE_CHECKING: from pydantic import BaseModel @@ -64,11 +63,26 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_text: Optional[str] = None, query_embedding: Optional[List[float]] = None, ascending: bool = True, + tags: Optional[List[str]] = None, + match_all_tags: bool = False, **kwargs, - ) -> List[Type["SqlalchemyBase"]]: + ) -> List["SqlalchemyBase"]: """ List records with cursor-based pagination, ordering by created_at. Cursor is an ID, but pagination is based on the cursor object's created_at value. + + Args: + db_session: SQLAlchemy session + cursor: ID of the last item seen (for pagination) + start_date: Filter items after this date + end_date: Filter items before this date + limit: Maximum number of items to return + query_text: Text to search for + query_embedding: Vector to search for similar embeddings + ascending: Sort direction + tags: List of tags to filter by + match_all_tags: If True, return items matching all tags. If False, match any tag. + **kwargs: Additional filters to apply """ if start_date and end_date and start_date > end_date: raise ValueError("start_date must be earlier than or equal to end_date") @@ -84,7 +98,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query = select(cls) - # Apply filtering logic + # Handle tag filtering if the model has tags + if tags and hasattr(cls, "tags"): + query = select(cls) + + if match_all_tags: + # Match ALL tags - use subqueries + for tag in tags: + subquery = select(cls.tags.property.mapper.class_.agent_id).where(cls.tags.property.mapper.class_.tag == tag) + query = query.filter(cls.id.in_(subquery)) + else: + # Match ANY tag - use join and filter + query = ( + query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id) # Deduplicate results + ) + + # Group by primary key and all necessary columns to avoid JSON comparison + query = query.group_by(cls.id) + + # Apply filtering logic from kwargs for key, value in kwargs.items(): column = getattr(cls, key) if isinstance(value, (list, tuple, set)): @@ -98,9 +130,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): if end_date: query = query.filter(cls.created_at < end_date) - # Cursor-based pagination using created_at - # TODO: There is a really nasty race condition issue here with Sqlite - # TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason + # Cursor-based pagination if cursor_obj: if ascending: query = query.where(cls.created_at >= cursor_obj.created_at).where( @@ -111,40 +141,34 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id) ) - # Apply text search + # Text search if query_text: - from sqlalchemy import func query = query.filter(func.lower(cls.text).contains(func.lower(query_text))) - # Apply embedding search (Passages) + # Embedding search (for Passages) is_ordered = False if query_embedding: - # check if embedding column exists. should only exist for passages if not hasattr(cls, "embedding"): raise ValueError(f"Class {cls.__name__} does not have an embedding column") - + from letta.settings import settings + if settings.letta_pg_uri_no_default: # PostgreSQL with pgvector - from pgvector.sqlalchemy import Vector query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc()) else: # SQLite with custom vector type - from sqlalchemy import func - query_embedding_binary = adapt_array(query_embedding) query = query.order_by( - func.cosine_distance(cls.embedding, query_embedding_binary).asc(), - cls.created_at.asc(), - cls.id.asc() + func.cosine_distance(cls.embedding, query_embedding_binary).asc(), cls.created_at.asc(), cls.id.asc() ) is_ordered = True - # Handle ordering and soft deletes + # Handle soft deletes if hasattr(cls, "is_deleted"): query = query.where(cls.is_deleted == False) - - # Apply ordering by created_at + + # Apply ordering if not is_ordered: if ascending: query = query.order_by(cls.created_at, cls.id) @@ -164,7 +188,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], access_type: AccessType = AccessType.ORGANIZATION, **kwargs, - ) -> Type["SqlalchemyBase"]: + ) -> "SqlalchemyBase": """The primary accessor for an ORM record. Args: db_session: the database session to use when retrieving the record @@ -207,7 +231,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions" raise NoResultFound(f"{cls.__name__} not found with {conditions_str}") - def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]: + def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}") if actor: @@ -221,7 +245,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): except DBAPIError as e: self._handle_dbapi_error(e) - def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]: + def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}") if actor: @@ -245,7 +269,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): else: logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted") - def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]: + def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}") if actor: self._set_created_and_updated_by_fields(actor.id) @@ -388,14 +412,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): raise @property - def __pydantic_model__(self) -> Type["BaseModel"]: + def __pydantic_model__(self) -> "BaseModel": raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.") - def to_pydantic(self) -> Type["BaseModel"]: + def to_pydantic(self) -> "BaseModel": """converts to the basic pydantic model counterpart""" return self.__pydantic_model__.model_validate(self) - def to_record(self) -> Type["BaseModel"]: + def to_record(self) -> "BaseModel": """Deprecated accessor for to_pydantic""" logger.warning("to_record is deprecated, use to_pydantic instead.") - return self.to_pydantic() \ No newline at end of file + return self.to_pydantic() diff --git a/letta/orm/tool.py b/letta/orm/tool.py index 8f1ac46a..a25c7ebb 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, String, UniqueConstraint, event +from sqlalchemy import JSON, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship # TODO everything in functions should live in this model @@ -11,7 +11,6 @@ from letta.schemas.tool import Tool as PydanticTool if TYPE_CHECKING: from letta.orm.organization import Organization - from letta.orm.tools_agents import ToolsAgents class Tool(SqlalchemyBase, OrganizationMixin): @@ -42,20 +41,3 @@ class Tool(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin") - tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan") - - -# Add event listener to update tool_name in ToolsAgents when Tool name changes -@event.listens_for(Tool, "before_update") -def update_tool_name_in_tools_agents(mapper, connection, target): - """Update tool_name in ToolsAgents when Tool name changes.""" - state = target._sa_instance_state - history = state.get_history("name", passive=True) - if not history.has_changes(): - return - - # Get the new name and update all associated ToolsAgents records - new_name = target.name - from letta.orm.tools_agents import ToolsAgents - - connection.execute(ToolsAgents.__table__.update().where(ToolsAgents.tool_id == target.id).values(tool_name=new_name)) diff --git a/letta/orm/tools_agents.py b/letta/orm/tools_agents.py index dfb8a9a7..52c1e0a1 100644 --- a/letta/orm/tools_agents.py +++ b/letta/orm/tools_agents.py @@ -1,32 +1,15 @@ -from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy import ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column -from letta.orm.sqlalchemy_base import SqlalchemyBase -from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents +from letta.orm import Base -class ToolsAgents(SqlalchemyBase): +class ToolsAgents(Base): """Agents can have one or many tools associated with them.""" __tablename__ = "tools_agents" - __pydantic_model__ = PydanticToolsAgents - __table_args__ = ( - UniqueConstraint( - "agent_id", - "tool_name", - name="unique_tool_per_agent", - ), - ForeignKeyConstraint( - ["tool_id"], - ["tools.id"], - name="fk_tool_id", - ), - ) + __table_args__ = (UniqueConstraint("agent_id", "tool_id", name="unique_agent_tool"),) # Each agent must have unique tool names - agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) - tool_id: Mapped[str] = mapped_column(String, primary_key=True) - tool_name: Mapped[str] = mapped_column(String, primary_key=True) - - # relationships - tool: Mapped["Tool"] = relationship("Tool", back_populates="tools_agents") # agent: Mapped["Agent"] = relationship("Agent", back_populates="tools_agents") + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True) + tool_id: Mapped[str] = mapped_column(String, ForeignKey("tools.id", ondelete="CASCADE"), primary_key=True) diff --git a/letta/orm/user.py b/letta/orm/user.py index 62a3c0e6..9f626b10 100644 --- a/letta/orm/user.py +++ b/letta/orm/user.py @@ -20,10 +20,9 @@ class User(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="users") - jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan") + jobs: Mapped[List["Job"]] = relationship( + "Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan" + ) # TODO: Add this back later potentially - # agents: Mapped[List["Agent"]] = relationship( - # "Agent", secondary="users_agents", back_populates="users", doc="the agents associated with this user." - # ) # tokens: Mapped[List["Token"]] = relationship("Token", back_populates="user", doc="the tokens associated with this user.") diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index f0372f8f..994233ab 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -1,13 +1,11 @@ -from datetime import datetime from enum import Enum from typing import Dict, List, Optional from pydantic import BaseModel, Field, field_validator -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.letta_base import LettaBase +from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.message import Message, MessageCreate @@ -15,15 +13,15 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.tool_rule import ToolRule +from letta.utils import create_random_username -class BaseAgent(LettaBase, validate_assignment=True): +class BaseAgent(OrmMetadataBase, validate_assignment=True): __id_prefix__ = "agent" description: Optional[str] = Field(None, description="The description of the agent.") # metadata metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") - user_id: Optional[str] = Field(None, description="The user id of the agent.") class AgentType(str, Enum): @@ -38,37 +36,7 @@ class AgentType(str, Enum): chat_only_agent = "chat_only_agent" -class PersistedAgentState(BaseAgent, validate_assignment=True): - # NOTE: this has been changed to represent the data stored in the ORM, NOT what is passed around internally or returned to the user - id: str = BaseAgent.generate_id_field() - name: str = Field(..., description="The name of the agent.") - created_at: datetime = Field(..., description="The datetime the agent was created.", default_factory=datetime.now) - - # in-context memory - message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.") - # tools - # TODO: move to ORM mapping - tool_names: List[str] = Field(..., description="The tools used by the agent.") - - # tool rules - tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.") - - # system prompt - system: str = Field(..., description="The system prompt used by the agent.") - - # agent configuration - agent_type: AgentType = Field(..., description="The type of agent.") - - # llm information - llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.") - embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.") - - class Config: - arbitrary_types_allowed = True - validate_assignment = True - - -class AgentState(PersistedAgentState): +class AgentState(BaseAgent): """ Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent. @@ -86,42 +54,53 @@ class AgentState(PersistedAgentState): """ # NOTE: this is what is returned to the client and also what is used to initialize `Agent` + id: str = BaseAgent.generate_id_field() + name: str = Field(..., description="The name of the agent.") + # tool rules + tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.") + + # in-context memory + message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.") + + # system prompt + system: str = Field(..., description="The system prompt used by the agent.") + + # agent configuration + agent_type: AgentType = Field(..., description="The type of agent.") + + # llm information + llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.") + embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.") # This is an object representing the in-process state of a running `Agent` # Field in this object can be theoretically edited by tools, and will be persisted by the ORM + organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the agent.") + memory: Memory = Field(..., description="The in-context memory of the agent.") tools: List[Tool] = Field(..., description="The tools used by the agent.") sources: List[Source] = Field(..., description="The sources used by the agent.") tags: List[str] = Field(..., description="The tags associated with the agent.") # TODO: add in context message objects - def to_persisted_agent_state(self) -> PersistedAgentState: - # turn back into persisted agent - data = self.model_dump() - del data["memory"] - del data["tools"] - del data["sources"] - del data["tags"] - return PersistedAgentState(**data) - class CreateAgent(BaseAgent): # # all optional as server can generate defaults - name: Optional[str] = Field(None, description="The name of the agent.") - message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.") + name: str = Field(default_factory=lambda: create_random_username(), description="The name of the agent.") # memory creation memory_blocks: List[CreateBlock] = Field( - # [CreateHuman(), CreatePersona()], description="The blocks to create in the agent's in-context memory." ..., description="The blocks to create in the agent's in-context memory.", ) - - tools: List[str] = Field(BASE_TOOLS + BASE_MEMORY_TOOLS, description="The tools used by the agent.") + # TODO: This is a legacy field and should be removed ASAP to force `tool_ids` usage + tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") + tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.") + source_ids: Optional[List[str]] = Field(None, description="The ids of the sources used by the agent.") + block_ids: Optional[List[str]] = Field(None, description="The ids of the blocks used by the agent.") tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.") tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.") - agent_type: AgentType = Field(AgentType.memgpt_agent, description="The type of agent.") + agent_type: AgentType = Field(default_factory=lambda: AgentType.memgpt_agent, description="The type of agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") # Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence @@ -129,6 +108,7 @@ class CreateAgent(BaseAgent): # initial_message_sequence: Optional[List[MessageCreate]] = Field( None, description="The initial set of messages to put in the agent's in-context memory." ) + include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.") @field_validator("name") @classmethod @@ -156,18 +136,21 @@ class CreateAgent(BaseAgent): # return name -class UpdateAgentState(BaseAgent): - id: str = Field(..., description="The id of the agent.") +class UpdateAgent(BaseAgent): name: Optional[str] = Field(None, description="The name of the agent.") - tool_names: Optional[List[str]] = Field(None, description="The tools used by the agent.") + tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.") + source_ids: Optional[List[str]] = Field(None, description="The ids of the sources used by the agent.") + block_ids: Optional[List[str]] = Field(None, description="The ids of the blocks used by the agent.") tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.") + tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") - - # TODO: determine if these should be editable via this schema? message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.") + class Config: + extra = "ignore" # Ignores extra fields + class AgentStepResponse(BaseModel): messages: List[Message] = Field(..., description="The messages generated during the agent's step.") diff --git a/letta/schemas/agents_tags.py b/letta/schemas/agents_tags.py deleted file mode 100644 index eba5e0db..00000000 --- a/letta/schemas/agents_tags.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime -from typing import Optional - -from pydantic import Field - -from letta.schemas.letta_base import LettaBase - - -class AgentsTagsBase(LettaBase): - __id_prefix__ = "agents_tags" - - -class AgentsTags(AgentsTagsBase): - """ - Schema representing the relationship between tags and agents. - - Parameters: - agent_id (str): The ID of the associated agent. - tag_id (str): The ID of the associated tag. - tag_name (str): The name of the tag. - created_at (datetime): The date this relationship was created. - """ - - id: str = AgentsTagsBase.generate_id_field() - agent_id: str = Field(..., description="The ID of the associated agent.") - tag: str = Field(..., description="The name of the tag.") - created_at: Optional[datetime] = Field(None, description="The creation date of the association.") - updated_at: Optional[datetime] = Field(None, description="The update date of the tag.") - is_deleted: bool = Field(False, description="Whether this tag is deleted or not.") - - -class AgentsTagsCreate(AgentsTagsBase): - tag: str = Field(..., description="The tag name.") diff --git a/letta/schemas/blocks_agents.py b/letta/schemas/blocks_agents.py deleted file mode 100644 index 8b33925a..00000000 --- a/letta/schemas/blocks_agents.py +++ /dev/null @@ -1,32 +0,0 @@ -from datetime import datetime -from typing import Optional - -from pydantic import Field - -from letta.schemas.letta_base import LettaBase - - -class BlocksAgentsBase(LettaBase): - __id_prefix__ = "blocks_agents" - - -class BlocksAgents(BlocksAgentsBase): - """ - Schema representing the relationship between blocks and agents. - - Parameters: - agent_id (str): The ID of the associated agent. - block_id (str): The ID of the associated block. - block_label (str): The label of the block. - created_at (datetime): The date this relationship was created. - updated_at (datetime): The date this relationship was last updated. - is_deleted (bool): Whether this block-agent relationship is deleted or not. - """ - - id: str = BlocksAgentsBase.generate_id_field() - agent_id: str = Field(..., description="The ID of the associated agent.") - block_id: str = Field(..., description="The ID of the associated block.") - block_label: str = Field(..., description="The label of the block.") - created_at: Optional[datetime] = Field(None, description="The creation date of the association.") - updated_at: Optional[datetime] = Field(None, description="The update date of the association.") - is_deleted: bool = Field(False, description="Whether this block-agent relationship is deleted or not.") diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 9084006d..797eac57 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -87,7 +87,7 @@ class Memory(BaseModel, validate_assignment=True): Template(prompt_template) # Validate compatibility with current memory structure - test_render = Template(prompt_template).render(blocks=self.blocks) + Template(prompt_template).render(blocks=self.blocks) # If we get here, the template is valid and compatible self.prompt_template = prompt_template @@ -213,6 +213,7 @@ class ChatMemory(BasicBlockMemory): human (str): The starter value for the human block. limit (int): The character limit for each block. """ + # TODO: Should these be CreateBlocks? super().__init__(blocks=[Block(value=persona, limit=limit, label="persona"), Block(value=human, limit=limit, label="human")]) diff --git a/letta/schemas/tools_agents.py b/letta/schemas/tools_agents.py deleted file mode 100644 index b7e8bdcf..00000000 --- a/letta/schemas/tools_agents.py +++ /dev/null @@ -1,32 +0,0 @@ -from datetime import datetime -from typing import Optional - -from pydantic import Field - -from letta.schemas.letta_base import LettaBase - - -class ToolsAgentsBase(LettaBase): - __id_prefix__ = "tools_agents" - - -class ToolsAgents(ToolsAgentsBase): - """ - Schema representing the relationship between tools and agents. - - Parameters: - agent_id (str): The ID of the associated agent. - tool_id (str): The ID of the associated tool. - tool_name (str): The name of the tool. - created_at (datetime): The date this relationship was created. - updated_at (datetime): The date this relationship was last updated. - is_deleted (bool): Whether this tool-agent relationship is deleted or not. - """ - - id: str = ToolsAgentsBase.generate_id_field() - agent_id: str = Field(..., description="The ID of the associated agent.") - tool_id: str = Field(..., description="The ID of the associated tool.") - tool_name: str = Field(..., description="The name of the tool.") - created_at: Optional[datetime] = Field(None, description="The creation date of the association.") - updated_at: Optional[datetime] = Field(None, description="The update date of the association.") - is_deleted: bool = Field(False, description="Whether this tool-agent relationship is deleted or not.") diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index b48a13d0..615811d7 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -25,9 +25,6 @@ from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.routers.openai.assistants.assistants import ( router as openai_assistants_router, ) -from letta.server.rest_api.routers.openai.assistants.threads import ( - router as openai_threads_router, -) from letta.server.rest_api.routers.openai.chat_completions.chat_completions import ( router as openai_chat_completions_router, ) @@ -215,7 +212,6 @@ def create_application() -> "FastAPI": # openai app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX) - app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX) app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX) # /api/auth endpoints @@ -236,7 +232,6 @@ def create_application() -> "FastAPI": @app.on_event("shutdown") def on_shutdown(): global server - server.save_agents() # server = None return app diff --git a/letta/server/rest_api/routers/openai/assistants/threads.py b/letta/server/rest_api/routers/openai/assistants/threads.py deleted file mode 100644 index 8742aa42..00000000 --- a/letta/server/rest_api/routers/openai/assistants/threads.py +++ /dev/null @@ -1,338 +0,0 @@ -import uuid -from typing import TYPE_CHECKING, List, Optional - -from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query - -from letta.constants import DEFAULT_PRESET -from letta.schemas.agent import CreateAgent -from letta.schemas.enums import MessageRole -from letta.schemas.message import Message -from letta.schemas.openai.openai import ( - MessageFile, - OpenAIMessage, - OpenAIRun, - OpenAIRunStep, - OpenAIThread, - Text, -) -from letta.server.rest_api.routers.openai.assistants.schemas import ( - CreateMessageRequest, - CreateRunRequest, - CreateThreadRequest, - CreateThreadRunRequest, - DeleteThreadResponse, - ListMessagesResponse, - ModifyMessageRequest, - ModifyRunRequest, - ModifyThreadRequest, - OpenAIThread, - SubmitToolOutputsToRunRequest, -) -from letta.server.rest_api.utils import get_letta_server -from letta.server.server import SyncServer - -if TYPE_CHECKING: - from letta.utils import get_utc_time - - -# TODO: implement mechanism for creating/authenticating users associated with a bearer token -router = APIRouter(prefix="/v1/threads", tags=["threads"]) - - -@router.post("/", response_model=OpenAIThread) -def create_thread( - request: CreateThreadRequest = Body(...), - server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -): - # TODO: use requests.description and requests.metadata fields - # TODO: handle requests.file_ids and requests.tools - # TODO: eventually allow request to override embedding/llm model - actor = server.get_user_or_default(user_id=user_id) - - print("Create thread/agent", request) - # create a letta agent - agent_state = server.create_agent( - request=CreateAgent(), - user_id=actor.id, - ) - # TODO: insert messages into recall memory - return OpenAIThread( - id=str(agent_state.id), - created_at=int(agent_state.created_at.timestamp()), - metadata={}, # TODO add metadata? - ) - - -@router.get("/{thread_id}", response_model=OpenAIThread) -def retrieve_thread( - thread_id: str = Path(..., description="The unique identifier of the thread."), - server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -): - actor = server.get_user_or_default(user_id=user_id) - agent = server.get_agent(user_id=actor.id, agent_id=thread_id) - assert agent is not None - return OpenAIThread( - id=str(agent.id), - created_at=int(agent.created_at.timestamp()), - metadata={}, # TODO add metadata? - ) - - -@router.get("/{thread_id}", response_model=OpenAIThread) -def modify_thread( - thread_id: str = Path(..., description="The unique identifier of the thread."), - request: ModifyThreadRequest = Body(...), -): - # TODO: add agent metadata so this can be modified - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.delete("/{thread_id}", response_model=DeleteThreadResponse) -def delete_thread( - thread_id: str = Path(..., description="The unique identifier of the thread."), -): - # TODO: delete agent - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.post("/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage) -def create_message( - thread_id: str = Path(..., description="The unique identifier of the thread."), - request: CreateMessageRequest = Body(...), - server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -): - actor = server.get_user_or_default(user_id=user_id) - agent_id = thread_id - # create message object - message = Message( - user_id=actor.id, - agent_id=agent_id, - role=MessageRole(request.role), - text=request.content, - model=None, - tool_calls=None, - tool_call_id=None, - name=None, - ) - agent = server.load_agent(agent_id=agent_id) - # add message to agent - agent._append_to_messages([message]) - - openai_message = OpenAIMessage( - id=str(message.id), - created_at=int(message.created_at.timestamp()), - content=[Text(text=(message.text if message.text else ""))], - role=message.role, - thread_id=str(message.agent_id), - assistant_id=DEFAULT_PRESET, # TODO: update this - # TODO(sarah) fill in? - run_id=None, - file_ids=None, - metadata=None, - # file_ids=message.file_ids, - # metadata=message.metadata, - ) - return openai_message - - -@router.get("/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse) -def list_messages( - thread_id: str = Path(..., description="The unique identifier of the thread."), - limit: int = Query(1000, description="How many messages to retrieve."), - order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."), - after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), - before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), - server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -): - actor = server.get_user_or_default(user_id) - after_uuid = after if before else None - before_uuid = before if before else None - agent_id = thread_id - reverse = True if (order == "desc") else False - json_messages = server.get_agent_recall_cursor( - user_id=actor.id, - agent_id=agent_id, - limit=limit, - after=after_uuid, - before=before_uuid, - order_by="created_at", - reverse=reverse, - ) - assert isinstance(json_messages, List) - assert all([isinstance(message, Message) for message in json_messages]) - assert isinstance(json_messages[0], Message) - print(json_messages[0].text) - # convert to openai style messages - openai_messages = [] - for message in json_messages: - assert isinstance(message, Message) - openai_messages.append( - OpenAIMessage( - id=str(message.id), - created_at=int(message.created_at.timestamp()), - content=[Text(text=(message.text if message.text else ""))], - role=str(message.role), - thread_id=str(message.agent_id), - assistant_id=DEFAULT_PRESET, # TODO: update this - # TODO(sarah) fill in? - run_id=None, - file_ids=None, - metadata=None, - # file_ids=message.file_ids, - # metadata=message.metadata, - ) - ) - print("MESSAGES", openai_messages) - # TODO: cast back to message objects - return ListMessagesResponse(messages=openai_messages) - - -@router.get("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage) -def retrieve_message( - thread_id: str = Path(..., description="The unique identifier of the thread."), - message_id: str = Path(..., description="The unique identifier of the message."), - server: SyncServer = Depends(get_letta_server), -): - agent_id = thread_id - message = server.get_agent_message(agent_id=agent_id, message_id=message_id) - assert message is not None - return OpenAIMessage( - id=message_id, - created_at=int(message.created_at.timestamp()), - content=[Text(text=(message.text if message.text else ""))], - role=message.role, - thread_id=str(message.agent_id), - assistant_id=DEFAULT_PRESET, # TODO: update this - # TODO(sarah) fill in? - run_id=None, - file_ids=None, - metadata=None, - # file_ids=message.file_ids, - # metadata=message.metadata, - ) - - -@router.get("/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile) -def retrieve_message_file( - thread_id: str = Path(..., description="The unique identifier of the thread."), - message_id: str = Path(..., description="The unique identifier of the message."), - file_id: str = Path(..., description="The unique identifier of the file."), -): - # TODO: implement? - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.post("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage) -def modify_message( - thread_id: str = Path(..., description="The unique identifier of the thread."), - message_id: str = Path(..., description="The unique identifier of the message."), - request: ModifyMessageRequest = Body(...), -): - # TODO: add metada field to message so this can be modified - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.post("/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun) -def create_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - request: CreateRunRequest = Body(...), - server: SyncServer = Depends(get_letta_server), -): - - # TODO: add request.instructions as a message? - agent_id = thread_id - # TODO: override preset of agent with request.assistant_id - agent = server.load_agent(agent_id=agent_id) - agent.inner_step(messages=[]) # already has messages added - run_id = str(uuid.uuid4()) - create_time = int(get_utc_time().timestamp()) - return OpenAIRun( - id=run_id, - created_at=create_time, - thread_id=str(agent_id), - assistant_id=DEFAULT_PRESET, # TODO: update this - status="completed", # TODO: eventaully allow offline execution - expires_at=create_time, - model=agent.agent_state.llm_config.model, - instructions=request.instructions, - ) - - -@router.post("/runs", tags=["runs"], response_model=OpenAIRun) -def create_thread_and_run( - request: CreateThreadRunRequest = Body(...), -): - # TODO: add a bunch of messages and execute - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.get("/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun]) -def list_runs( - thread_id: str = Path(..., description="The unique identifier of the thread."), - limit: int = Query(1000, description="How many runs to retrieve."), - order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."), - after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), - before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), -): - # TODO: store run information in a DB so it can be returned here - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.get("/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep]) -def list_run_steps( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - limit: int = Query(1000, description="How many run steps to retrieve."), - order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."), - after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), - before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), -): - # TODO: store run information in a DB so it can be returned here - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.get("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun) -def retrieve_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), -): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.get("/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep) -def retrieve_run_step( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - step_id: str = Path(..., description="The unique identifier of the run step."), -): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.post("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun) -def modify_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - request: ModifyRunRequest = Body(...), -): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun) -def submit_tool_outputs_to_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - request: SubmitToolOutputsToRunRequest = Body(...), -): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - -@router.post("/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun) -def cancel_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), -): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 47031042..3dc7916a 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -36,7 +36,7 @@ async def create_chat_completion( The bearer token will be used to identify the user. The 'user' field in the completion_request should be set to the agent ID. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) agent_id = completion_request.user if agent_id is None: diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 06b0acd6..fc5ce507 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -17,7 +17,8 @@ from fastapi.responses import JSONResponse, StreamingResponse from pydantic import Field from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG -from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState +from letta.orm.errors import NoResultFound +from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.block import ( # , BlockLabelUpdate, BlockLimitUpdate Block, BlockUpdate, @@ -54,23 +55,38 @@ from letta.server.server import SyncServer router = APIRouter(prefix="/agents", tags=["agents"]) +# TODO: This should be paginated @router.get("/", response_model=List[AgentState], operation_id="list_agents") def list_agents( name: Optional[str] = Query(None, description="Name of the agent"), tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"), + match_all_tags: bool = Query( + False, + description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed in tags.", + ), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), + # Extract user_id from header, default to None if not present ): """ List all agents associated with a given user. This endpoint retrieves a list of all agents and their configurations associated with the specified user ID. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) - agents = server.list_agents(user_id=actor.id, tags=tags) - # TODO: move this logic to the ORM - if name: - agents = [a for a in agents if a.name == name] + # Use dictionary comprehension to build kwargs dynamically + kwargs = { + key: value + for key, value in { + "tags": tags, + "match_all_tags": match_all_tags, + "name": name, + }.items() + if value is not None + } + + # Call list_agents with the dynamic kwargs + agents = server.agent_manager.list_agents(actor=actor, **kwargs) return agents @@ -83,7 +99,7 @@ def get_agent_context_window( """ Retrieve the context window of a specific agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id) @@ -106,20 +122,20 @@ def create_agent( """ Create a new agent with the specified configuration. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.create_agent(agent, actor=actor) @router.patch("/{agent_id}", response_model=AgentState, operation_id="update_agent") def update_agent( agent_id: str, - update_agent: UpdateAgentState = Body(...), + update_agent: UpdateAgent = Body(...), server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Update an exsiting agent""" - actor = server.get_user_or_default(user_id=user_id) - return server.update_agent(update_agent, actor=actor) + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.update_agent(agent_id, update_agent, actor=actor) @router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent") @@ -129,7 +145,7 @@ def get_tools_from_agent( user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Get tools from an existing agent""" - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id) @@ -141,7 +157,7 @@ def add_tool_to_agent( user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Add tools to an existing agent""" - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) @@ -153,7 +169,7 @@ def remove_tool_from_agent( user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Add tools to an existing agent""" - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) @@ -166,13 +182,12 @@ def get_agent_state( """ Get the state of the agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) - if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id): - # agent does not exist - raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.") - - return server.get_agent_state(user_id=actor.id, agent_id=agent_id) + try: + return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) @router.delete("/{agent_id}", response_model=AgentState, operation_id="delete_agent") @@ -184,38 +199,37 @@ def delete_agent( """ Delete an agent. """ - actor = server.get_user_or_default(user_id=user_id) - - agent = server.get_agent(agent_id) - if not agent: - raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.") - - server.delete_agent(user_id=actor.id, agent_id=agent_id) - return agent + actor = server.user_manager.get_user_or_default(user_id=user_id) + try: + return server.agent_manager.delete_agent(agent_id=agent_id, actor=actor) + except NoResultFound: + raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.") @router.get("/{agent_id}/sources", response_model=List[Source], operation_id="get_agent_sources") def get_agent_sources( agent_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get the sources associated with an agent. """ - - return server.list_attached_sources(agent_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor) @router.get("/{agent_id}/memory/messages", response_model=List[Message], operation_id="list_agent_in_context_messages") def get_agent_in_context_messages( agent_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the messages in the context of a specific agent. """ - - return server.get_in_context_messages(agent_id=agent_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.get_in_context_messages(agent_id=agent_id, actor=actor) # TODO: remove? can also get with agent blocks @@ -223,13 +237,15 @@ def get_agent_in_context_messages( def get_agent_memory( agent_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the memory state of a specific agent. This endpoint fetches the current memory state of the agent identified by the user ID and agent ID. """ + actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.get_agent_memory(agent_id=agent_id) + return server.get_agent_memory(agent_id=agent_id, actor=actor) @router.get("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="get_agent_memory_block") @@ -242,10 +258,12 @@ def get_agent_memory_block( """ Retrieve a memory block from an agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) - block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label) - return server.block_manager.get_block_by_id(block_id, actor=actor) + try: + return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor) + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) @router.get("/{agent_id}/memory/block", response_model=List[Block], operation_id="get_agent_memory_blocks") @@ -257,9 +275,12 @@ def get_agent_memory_blocks( """ Retrieve the memory blocks of a specific agent. """ - actor = server.get_user_or_default(user_id=user_id) - block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id) - return [server.block_manager.get_block_by_id(block_id, actor=actor) for block_id in block_ids] + actor = server.user_manager.get_user_or_default(user_id=user_id) + try: + agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor) + return agent.memory.blocks + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) @router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block") @@ -272,16 +293,17 @@ def add_agent_memory_block( """ Creates a memory block and links it to the agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) # Copied from POST /blocks + # TODO: Should have block_manager accept only CreateBlock + # TODO: This will be possible once we move ID creation to the ORM block_req = Block(**create_block.model_dump()) block = server.block_manager.create_or_update_block(actor=actor, block=block_req) # Link the block to the agent - updated_memory = server.link_block_to_agent_memory(user_id=actor.id, agent_id=agent_id, block_id=block.id) - - return updated_memory + agent = server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=actor) + return agent.memory @router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block_by_label") @@ -296,56 +318,56 @@ def remove_agent_memory_block( """ Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) # Unlink the block from the agent - updated_memory = server.unlink_block_from_agent_memory(user_id=actor.id, agent_id=agent_id, block_label=block_label) + agent = server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor) - return updated_memory + return agent.memory @router.patch("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="update_agent_memory_block_by_label") def update_agent_memory_block( agent_id: str, block_label: str, - update_block: BlockUpdate = Body(...), + block_update: BlockUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) - # get the block_id from the label - block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label) - - # update the block - return server.block_manager.update_block(block_id=block_id, block_update=update_block, actor=actor) + block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor) + return server.block_manager.update_block(block.id, block_update=block_update, actor=actor) @router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary") def get_agent_recall_memory_summary( agent_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the summary of the recall memory of a specific agent. """ + actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.get_recall_memory_summary(agent_id=agent_id) + return server.get_recall_memory_summary(agent_id=agent_id, actor=actor) @router.get("/{agent_id}/memory/archival", response_model=ArchivalMemorySummary, operation_id="get_agent_archival_memory_summary") def get_agent_archival_memory_summary( agent_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the summary of the archival memory of a specific agent. """ - - return server.get_archival_memory_summary(agent_id=agent_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.get_archival_memory_summary(agent_id=agent_id, actor=actor) @router.get("/{agent_id}/archival", response_model=List[Passage], operation_id="list_agent_archival_memory") @@ -360,7 +382,7 @@ def get_agent_archival_memory( """ Retrieve the memories in an agent's archival memory store (paginated query). """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) # TODO need to add support for non-postgres here # chroma will throw: @@ -369,7 +391,7 @@ def get_agent_archival_memory( return server.get_agent_archival_cursor( user_id=actor.id, agent_id=agent_id, - cursor=after, # TODO: deleting before, after. is this expected? + cursor=after, # TODO: deleting before, after. is this expected? limit=limit, ) @@ -384,9 +406,9 @@ def insert_agent_archival_memory( """ Insert a memory into an agent's archival memory store. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text) + return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor) # TODO(ethan): query or path parameter for memory_id? @@ -402,9 +424,9 @@ def delete_agent_archival_memory( """ Delete a memory from an agent's archival memory store. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) - server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id) + server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=actor) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) @@ -429,7 +451,7 @@ def get_agent_messages( """ Retrieve message history for an agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.get_agent_recall_cursor( user_id=actor.id, @@ -449,11 +471,13 @@ def update_message( message_id: str, request: MessageUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update the details of a message associated with an agent. """ - return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request) + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request, actor=actor) @router.post( @@ -471,11 +495,11 @@ async def send_message( Process a user message and return the agent's response. This endpoint accepts a message from a user and processes it through the agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) result = await send_message_to_agent( server=server, agent_id=agent_id, - user_id=actor.id, + actor=actor, messages=request.messages, stream_steps=False, stream_tokens=False, @@ -511,11 +535,11 @@ async def send_message_streaming( It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) result = await send_message_to_agent( server=server, agent_id=agent_id, - user_id=actor.id, + actor=actor, messages=request.messages, stream_steps=True, stream_tokens=request.stream_tokens, @@ -531,7 +555,6 @@ async def process_message_background( server: SyncServer, actor: User, agent_id: str, - user_id: str, messages: list, assistant_message_tool_name: str, assistant_message_tool_kwarg: str, @@ -542,7 +565,7 @@ async def process_message_background( result = await send_message_to_agent( server=server, agent_id=agent_id, - user_id=user_id, + actor=actor, messages=messages, stream_steps=False, # NOTE(matt) stream_tokens=False, @@ -585,7 +608,7 @@ async def send_message_async( Asynchronously process a user message and return a job ID. The actual processing happens in the background, and the status can be checked using the job ID. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) # Create a new job job = Job( @@ -605,7 +628,6 @@ async def send_message_async( server=server, actor=actor, agent_id=agent_id, - user_id=actor.id, messages=request.messages, assistant_message_tool_name=request.assistant_message_tool_name, assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, @@ -618,7 +640,7 @@ async def send_message_async( async def send_message_to_agent( server: SyncServer, agent_id: str, - user_id: str, + actor: User, # role: MessageRole, messages: Union[List[Message], List[MessageCreate]], stream_steps: bool, @@ -645,8 +667,7 @@ async def send_message_to_agent( # Get the generator object off of the agent's streaming interface # This will be attached to the POST SSE request used under-the-hood - # letta_agent = server.load_agent(agent_id=agent_id) - letta_agent = server.load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id, actor=actor) # Disable token streaming if not OpenAI # TODO: cleanup this logic @@ -685,7 +706,7 @@ async def send_message_to_agent( task = asyncio.create_task( asyncio.to_thread( server.send_messages, - user_id=user_id, + actor=actor, agent_id=agent_id, messages=messages, interface=streaming_interface, diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index f58172d6..d9213233 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -1,10 +1,9 @@ from typing import TYPE_CHECKING, List, Optional -from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Response from letta.orm.errors import NoResultFound from letta.schemas.block import Block, BlockUpdate, CreateBlock -from letta.schemas.memory import Memory from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -23,7 +22,7 @@ def list_blocks( server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name) @@ -33,7 +32,7 @@ def create_block( server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) block = Block(**create_block.model_dump()) return server.block_manager.create_or_update_block(actor=actor, block=block) @@ -41,12 +40,12 @@ def create_block( @router.patch("/{block_id}", response_model=Block, operation_id="update_memory_block") def update_block( block_id: str, - update_block: BlockUpdate = Body(...), + block_update: BlockUpdate = Body(...), server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.get_user_or_default(user_id=user_id) - return server.block_manager.update_block(block_id=block_id, block_update=update_block, actor=actor) + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor) @router.delete("/{block_id}", response_model=Block, operation_id="delete_memory_block") @@ -55,7 +54,7 @@ def delete_block( server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.block_manager.delete_block(block_id=block_id, actor=actor) @@ -66,7 +65,7 @@ def get_block( user_id: Optional[str] = Header(None, alias="user_id"), ): print("call get block", block_id) - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) try: block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor) if block is None: @@ -76,7 +75,7 @@ def get_block( raise HTTPException(status_code=404, detail="Block not found") -@router.patch("/{block_id}/attach", response_model=Block, operation_id="link_agent_memory_block") +@router.patch("/{block_id}/attach", response_model=None, status_code=204, operation_id="link_agent_memory_block") def link_agent_memory_block( block_id: str, agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), @@ -86,17 +85,16 @@ def link_agent_memory_block( """ Link a memory block to an agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) - block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor) - if block is None: - raise HTTPException(status_code=404, detail="Block not found") - - server.blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block.label) - return block + try: + server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor) + return Response(status_code=204) + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) -@router.patch("/{block_id}/detach", response_model=Memory, operation_id="unlink_agent_memory_block") +@router.patch("/{block_id}/detach", response_model=None, status_code=204, operation_id="unlink_agent_memory_block") def unlink_agent_memory_block( block_id: str, agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), @@ -106,11 +104,10 @@ def unlink_agent_memory_block( """ Unlink a memory block from an agent """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) - block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor) - if block is None: - raise HTTPException(status_code=404, detail="Block not found") - # Link the block to the agent - server.blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id) - return block + try: + server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor) + return Response(status_code=204) + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index e726062f..4245d2f9 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -20,7 +20,7 @@ def list_jobs( """ List all jobs. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) # TODO: add filtering by status jobs = server.job_manager.list_jobs(actor=actor) @@ -40,7 +40,7 @@ def list_active_jobs( """ List all active jobs. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running]) @@ -54,7 +54,7 @@ def get_job( """ Get the status of a job. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) try: return server.job_manager.get_job_by_id(job_id=job_id, actor=actor) @@ -71,7 +71,7 @@ def delete_job( """ Delete a job by its job_id. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) try: job = server.job_manager.delete_job_by_id(job_id=job_id, actor=actor) diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py index b276e339..bf06bae7 100644 --- a/letta/server/rest_api/routers/v1/sandbox_configs.py +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -25,7 +25,7 @@ def create_sandbox_config( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor) @@ -35,7 +35,7 @@ def create_default_e2b_sandbox_config( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor) @@ -44,7 +44,7 @@ def create_default_local_sandbox_config( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor) @@ -55,7 +55,7 @@ def update_sandbox_config( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor) @@ -65,7 +65,7 @@ def delete_sandbox_config( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor) @@ -76,7 +76,7 @@ def list_sandbox_configs( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, cursor=cursor) @@ -90,7 +90,7 @@ def create_sandbox_env_var( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor) @@ -101,7 +101,7 @@ def update_sandbox_env_var( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor) @@ -111,7 +111,7 @@ def delete_sandbox_env_var( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor) @@ -123,5 +123,5 @@ def list_sandbox_env_vars( server: SyncServer = Depends(get_letta_server), user_id: str = Depends(get_user_id), ): - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, cursor=cursor) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 6b45e1d0..bcc3203d 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -36,7 +36,7 @@ def get_source( """ Get all sources """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) if not source: @@ -53,7 +53,7 @@ def get_source_id_by_name( """ Get a source by name """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor) if not source: @@ -69,7 +69,7 @@ def list_sources( """ List all data sources created by a user. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.list_all_sources(actor=actor) @@ -83,7 +83,7 @@ def create_source( """ Create a new data source. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) source = Source(**source_create.model_dump()) return server.source_manager.create_source(source=source, actor=actor) @@ -99,7 +99,7 @@ def update_source( """ Update the name or documentation of an existing data source. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor): raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.") return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor) @@ -114,7 +114,7 @@ def delete_source( """ Delete a data source. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) server.delete_source(source_id=source_id, actor=actor) @@ -129,7 +129,7 @@ def attach_source_to_agent( """ Attach a data source to an existing agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) assert source is not None, f"Source with id={source_id} not found." @@ -147,7 +147,7 @@ def detach_source_from_agent( """ Detach a data source from an existing agent. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id) @@ -163,7 +163,7 @@ def upload_file_to_source( """ Upload a file to a data source. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) assert source is not None, f"Source with id={source_id} not found." @@ -197,7 +197,7 @@ def list_passages( """ List all passages associated with a data source. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id) return passages @@ -213,7 +213,7 @@ def list_files_from_source( """ List paginated files associated with a data source. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor) @@ -229,7 +229,7 @@ def delete_file_from_source( """ Delete a data source. """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor) if deleted_file is None: diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index d288ca02..e1eb5919 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -25,7 +25,7 @@ def delete_tool( """ Delete a tool by name """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor) @@ -38,7 +38,7 @@ def get_tool( """ Get a tool by ID """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor) if tool is None: # return 404 error @@ -55,7 +55,7 @@ def get_tool_id( """ Get a tool ID by name """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) tool = server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) if tool: return tool.id @@ -74,7 +74,7 @@ def list_tools( Get a list of all tools available to agents belonging to the org of the user """ try: - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.tool_manager.list_tools(actor=actor, cursor=cursor, limit=limit) except Exception as e: # Log or print the full exception here for debugging @@ -92,7 +92,7 @@ def create_tool( Create a new tool """ try: - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) tool = Tool(**request.model_dump()) return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor) except UniqueConstraintViolationError as e: @@ -124,7 +124,7 @@ def upsert_tool( Create or update a tool """ try: - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor) return tool except UniqueConstraintViolationError as e: @@ -147,7 +147,7 @@ def update_tool( """ Update an existing tool """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.tool_manager.update_tool_by_id(tool_id=tool_id, tool_update=request, actor=actor) @@ -159,7 +159,7 @@ def add_base_tools( """ Add base tools """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) return server.tool_manager.add_base_tools(actor=actor) @@ -173,7 +173,7 @@ def add_base_tools( # """ # Run an existing tool on provided arguments # """ -# actor = server.get_user_or_default(user_id=user_id) +# actor = server.user_manager.get_user_or_default(user_id=user_id) # return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id) @@ -187,7 +187,7 @@ def run_tool_from_source( """ Attempt to build a tool from source, then run it on the provided arguments """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) try: return server.run_tool_from_source( @@ -220,7 +220,7 @@ def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id: """ Get a list of all Composio apps """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) composio_api_key = get_composio_key(server, actor=actor) return server.get_composio_apps(api_key=composio_api_key) @@ -234,7 +234,7 @@ def list_composio_actions_by_app( """ Get a list of all Composio actions for a specific app """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) composio_api_key = get_composio_key(server, actor=actor) return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name, api_key=composio_api_key) @@ -248,7 +248,7 @@ def add_composio_tool( """ Add a new Composio tool by action name (Composio refers to each tool as an `Action`) """ - actor = server.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=user_id) composio_api_key = get_composio_key(server, actor=actor) try: diff --git a/letta/server/server.py b/letta/server/server.py index 9c66832e..d0c1d062 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -19,7 +19,7 @@ from letta.agent import Agent, save_agent from letta.chat_only_agent import ChatOnlyAgent from letta.credentials import LettaCredentials from letta.data_sources.connectors import DataConnector, load_data -from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError +from letta.errors import LettaAgentNotFoundError # TODO use custom interface from letta.interface import AgentInterface # abstract @@ -30,7 +30,6 @@ from letta.o1_agent import O1Agent from letta.offline_memory_agent import OfflineMemoryAgent from letta.orm import Base from letta.orm.errors import NoResultFound -from letta.prompts import gpt_system from letta.providers import ( AnthropicProvider, AzureProvider, @@ -44,15 +43,9 @@ from letta.providers import ( VLLMChatCompletionsProvider, VLLMCompletionsProvider, ) -from letta.schemas.agent import ( - AgentState, - AgentType, - CreateAgent, - PersistedAgentState, - UpdateAgentState, -) +from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent from letta.schemas.api_key import APIKey, APIKeyCreate -from letta.schemas.block import Block, BlockUpdate +from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig # openai schemas @@ -68,14 +61,13 @@ from letta.schemas.memory import ( ) from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate from letta.schemas.organization import Organization -from letta.schemas.passage import Passage as PydanticPassage +from letta.schemas.passage import Passage from letta.schemas.source import Source from letta.schemas.tool import Tool, ToolCreate from letta.schemas.usage import LettaUsageStatistics -from letta.schemas.user import User as PydanticUser -from letta.services.agents_tags_manager import AgentsTagsManager +from letta.schemas.user import User +from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager -from letta.services.blocks_agents_manager import BlocksAgentsManager from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.organization_manager import OrganizationManager @@ -85,9 +77,8 @@ from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager -from letta.services.tools_agents_manager import ToolsAgentsManager from letta.services.user_manager import UserManager -from letta.utils import create_random_username, get_utc_time, json_dumps, json_loads +from letta.utils import get_utc_time, json_dumps, json_loads logger = get_logger(__name__) @@ -105,18 +96,13 @@ class Server(object): """Return the memory of an agent (core memory + non-core statistics)""" raise NotImplementedError - @abstractmethod - def get_agent_state(self, user_id: str, agent_id: str) -> dict: - """Return the config of an agent""" - raise NotImplementedError - @abstractmethod def get_server_config(self, user_id: str) -> dict: """Return the base config""" raise NotImplementedError @abstractmethod - def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict: + def update_agent_core_memory(self, user_id: str, agent_id: str, label: str, actor: User) -> Memory: """Update the agents core memory block, return the new state""" raise NotImplementedError @@ -124,7 +110,7 @@ class Server(object): def create_agent( self, request: CreateAgent, - actor: PydanticUser, + actor: User, # interface interface: Union[AgentInterface, None] = None, ) -> AgentState: @@ -270,10 +256,6 @@ class SyncServer(Server): # auth_mode: str = "none", # "none, "jwt", "external" ): """Server process holds in-memory agents that are being run""" - - # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts - self.active_agents = [] - # chaining = whether or not to run again if request_heartbeat=true self.chaining = chaining @@ -307,12 +289,10 @@ class SyncServer(Server): self.tool_manager = ToolManager() self.block_manager = BlockManager() self.source_manager = SourceManager() - self.agents_tags_manager = AgentsTagsManager() self.sandbox_config_manager = SandboxConfigManager(tool_settings) - self.blocks_agents_manager = BlocksAgentsManager() self.message_manager = MessageManager() - self.tools_agents_manager = ToolsAgentsManager() self.job_manager = JobManager() + self.agent_manager = AgentManager() # Managers that interface with parallelism self.per_agent_lock_manager = PerAgentLockManager() @@ -397,42 +377,9 @@ class SyncServer(Server): ) ) - def save_agents(self): - """Saves all the agents that are in the in-memory object store""" - for agent_d in self.active_agents: - try: - save_agent(agent_d["agent"], self.ms) - logger.info(f"Saved agent {agent_d['agent_id']}") - except Exception as e: - logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}:\n{e}") - - def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]: - """Get the agent object from the in-memory object store""" - for d in self.active_agents: - if d["user_id"] == str(user_id) and d["agent_id"] == str(agent_id): - return d["agent"] - return None - - def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None: - """Put an agent object inside the in-memory object store""" - # Make sure the agent doesn't already exist - if self._get_agent(user_id=user_id, agent_id=agent_id) is not None: - # Can be triggered on concucrent request, so don't throw a full error - logger.exception(f"Agent (user={user_id}, agent={agent_id}) is already loaded") - return - # Add Agent instance to the in-memory list - self.active_agents.append( - { - "user_id": str(user_id), - "agent_id": str(agent_id), - "agent": agent_obj, - } - ) - - def initialize_agent(self, agent_id, interface: Union[AgentInterface, None] = None, initial_message_sequence=None) -> Agent: + def initialize_agent(self, agent_id, actor, interface: Union[AgentInterface, None] = None, initial_message_sequence=None) -> Agent: """Initialize an agent from the database""" - agent_state = self.get_agent(agent_id=agent_id) - actor = self.user_manager.get_user_by_id(user_id=agent_state.user_id) + agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) interface = interface or self.default_interface_factory() if agent_state.agent_type == AgentType.memgpt_agent: @@ -446,19 +393,20 @@ class SyncServer(Server): agent = O1Agent(agent_state=agent_state, interface=interface, user=actor) # Persist to agent - save_agent(agent, self.ms) + save_agent(agent) return agent - def load_agent(self, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent: + def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent: """Updated method to load agents from persisted storage""" agent_lock = self.per_agent_lock_manager.get_lock(agent_id) with agent_lock: - agent_state = self.get_agent(agent_id=agent_id) + agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + if agent_state is None: raise LettaAgentNotFoundError(f"Agent (agent_id={agent_id}) does not exist") - elif agent_state.user_id is None: + elif agent_state.created_by_id is None: raise ValueError(f"Agent (agent_id={agent_id}) does not have a user_id") - actor = self.user_manager.get_user_by_id(user_id=agent_state.user_id) + actor = self.user_manager.get_user_by_id(user_id=agent_state.created_by_id) interface = interface or self.default_interface_factory() if agent_state.agent_type == AgentType.memgpt_agent: @@ -476,19 +424,19 @@ class SyncServer(Server): agent.rebuild_system_prompt() # Persist to agent - save_agent(agent, self.ms) + save_agent(agent) return agent def _step( self, - user_id: str, + actor: User, agent_id: str, input_messages: Union[Message, List[Message]], interface: Union[AgentInterface, None] = None, # needed to getting responses # timestamp: Optional[datetime], ) -> LettaUsageStatistics: """Send the input message through the agent""" - + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user # Input validation if isinstance(input_messages, Message): input_messages = [input_messages] @@ -498,10 +446,7 @@ class SyncServer(Server): logger.debug(f"Got input messages: {input_messages}") letta_agent = None try: - - # Get the agent object (loaded in memory) - # letta_agent = self._get_or_load_agent(agent_id=agent_id) - letta_agent = self.load_agent(agent_id=agent_id, interface=interface) + letta_agent = self.load_agent(agent_id=agent_id, interface=interface, actor=actor) if letta_agent is None: raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded") @@ -519,7 +464,7 @@ class SyncServer(Server): ) # save agent after step - save_agent(letta_agent, self.ms) + save_agent(letta_agent) except Exception as e: logger.error(f"Error in server._step: {e}") @@ -534,11 +479,13 @@ class SyncServer(Server): def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: """Process a CLI command""" + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) logger.debug(f"Got command: {command}") # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) usage = None if command.lower() == "exit": @@ -546,7 +493,7 @@ class SyncServer(Server): raise ValueError(command) elif command.lower() == "save" or command.lower() == "savechat": - save_agent(letta_agent, self.ms) + save_agent(letta_agent) elif command.lower() == "attach": # Different from CLI, we extract the data source name from the command @@ -560,8 +507,8 @@ class SyncServer(Server): letta_agent.attach_source( user=self.user_manager.get_user_by_id(user_id=user_id), source_id=data_source, - source_manager=letta_agent.source_manager, - ms=self.ms, + source_manager=self.source_manager, + agent_manager=self.agent_manager, ) elif command.lower() == "dump" or command.lower().startswith("dump "): @@ -637,11 +584,11 @@ class SyncServer(Server): elif command.lower() == "heartbeat": input_message = system.get_heartbeat() - usage = self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) + usage = self._step(actor=actor, agent_id=agent_id, input_message=input_message) elif command.lower() == "memorywarning": input_message = system.get_token_limit_warning() - usage = self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) + usage = self._step(actor=actor, agent_id=agent_id, input_message=input_message) if not usage: usage = LettaUsageStatistics() @@ -656,9 +603,14 @@ class SyncServer(Server): timestamp: Optional[datetime] = None, ) -> LettaUsageStatistics: """Process an incoming user message and feed it through the Letta agent""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: + try: + actor = self.user_manager.get_user_by_id(user_id=user_id) + except NoResultFound: raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: + + try: + agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + except NoResultFound: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Basic input sanitization @@ -692,7 +644,7 @@ class SyncServer(Server): ) # Run the agent state forward - usage = self._step(user_id=user_id, agent_id=agent_id, input_messages=message) + usage = self._step(actor=actor, agent_id=agent_id, input_messages=message) return usage def system_message( @@ -703,9 +655,14 @@ class SyncServer(Server): timestamp: Optional[datetime] = None, ) -> LettaUsageStatistics: """Process an incoming system message and feed it through the Letta agent""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: + try: + actor = self.user_manager.get_user_by_id(user_id=user_id) + except NoResultFound: raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: + + try: + agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + except NoResultFound: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Basic input sanitization @@ -752,11 +709,11 @@ class SyncServer(Server): message.created_at = timestamp # Run the agent state forward - return self._step(user_id=user_id, agent_id=agent_id, input_messages=message) + return self._step(actor=actor, agent_id=agent_id, input_messages=message) def send_messages( self, - user_id: str, + actor: User, agent_id: str, messages: Union[List[MessageCreate], List[Message]], # whether or not to wrap user and system message as MemGPT-style stringified JSON @@ -771,11 +728,6 @@ class SyncServer(Server): Otherwise, we can pass them in directly. """ - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") - message_objects: List[Message] = [] if all(isinstance(m, MessageCreate) for m in messages): @@ -814,16 +766,11 @@ class SyncServer(Server): raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(message) for message in messages]}") # Run the agent state forward - return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects, interface=interface) + return self._step(actor=actor, agent_id=agent_id, input_messages=message_objects, interface=interface) # @LockingServer.agent_lock_decorator def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: """Run a command on the agent""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") - # If the input begins with a command prefix, attempt to process it as a command if command.startswith("/"): if len(command) > 1: @@ -833,86 +780,16 @@ class SyncServer(Server): def create_agent( self, request: CreateAgent, - actor: PydanticUser, + actor: User, # interface interface: Union[AgentInterface, None] = None, ) -> AgentState: """Create a new agent using a config""" - user_id = actor.id - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - - if interface is None: - interface = self.default_interface_factory() - - # create agent name - if request.name is None: - request.name = create_random_username() - - if request.agent_type is None: - request.agent_type = AgentType.memgpt_agent - - # system debug - if request.system is None: - # TODO: don't hardcode - if request.agent_type == AgentType.memgpt_agent: - request.system = gpt_system.get_system_text("memgpt_chat") - elif request.agent_type == AgentType.o1_agent: - request.system = gpt_system.get_system_text("memgpt_modified_o1") - elif request.agent_type == AgentType.offline_memory_agent: - request.system = gpt_system.get_system_text("memgpt_offline_memory") - elif request.agent_type == AgentType.chat_only_agent: - request.system = gpt_system.get_system_text("memgpt_convo_only") - else: - raise ValueError(f"Invalid agent type: {request.agent_type}") - - # create blocks (note: cannot be linked into the agent_id is created) - blocks = [] - for create_block in request.memory_blocks: - block = self.block_manager.create_or_update_block(Block(**create_block.model_dump()), actor=actor) - blocks.append(block) - - # get tools + only add if they exist - tool_objs = [] - if request.tools: - for tool_name in request.tools: - tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) - if tool_obj: - tool_objs.append(tool_obj) - else: - warnings.warn(f"Attempted to add a nonexistent tool {tool_name} to agent {request.name}, skipping.") - # reset the request.tools to only valid tools - request.tools = [t.name for t in tool_objs] - - # get the user - logger.debug(f"Attempting to find user: {user_id}") - user = self.user_manager.get_user_by_id(user_id=user_id) - if not user: - raise ValueError(f"cannot find user with associated client id: {user_id}") - - if request.llm_config is None: - raise ValueError("llm_config is required") - - if request.embedding_config is None: - raise ValueError("embedding_config is required") - - # created and persist the agent state in the DB - agent_state = PersistedAgentState( - name=request.name, - user_id=user_id, - tool_names=request.tools if request.tools else [], - tool_rules=request.tool_rules, - agent_type=request.agent_type or AgentType.memgpt_agent, - llm_config=request.llm_config, - embedding_config=request.embedding_config, - system=request.system, - # other metadata - description=request.description, - metadata_=request.metadata_, + # Invoke manager + agent_state = self.agent_manager.create_agent( + agent_create=request, + actor=actor, ) - # TODO: move this to agent ORM - # this saves the agent ID and state into the DB - self.ms.create_agent(agent_state) # create the agent object if request.initial_message_sequence is not None: @@ -937,81 +814,29 @@ class SyncServer(Server): init_messages = None # initialize the agent (generates initial message list with system prompt) - self.initialize_agent(agent_id=agent_state.id, interface=interface, initial_message_sequence=init_messages) + if interface is None: + interface = self.default_interface_factory() + self.initialize_agent(agent_id=agent_state.id, interface=interface, initial_message_sequence=init_messages, actor=actor) - # Note: mappings (e.g. tags, blocks) are created after the agent is persisted - # TODO: add source mappings here as well - - # create the tags - if request.tags: - for tag in request.tags: - self.agents_tags_manager.add_tag_to_agent(agent_id=agent_state.id, tag=tag, actor=actor) - - # create block mappins (now that agent is persisted) - for block in blocks: - # this links the created block to the agent - self.blocks_agents_manager.add_block_to_agent(block_id=block.id, agent_id=agent_state.id, block_label=block.label) - - in_memory_agent_state = self.get_agent(agent_state.id) + in_memory_agent_state = self.agent_manager.get_agent_by_id(agent_state.id, actor=actor) return in_memory_agent_state - def get_agent(self, agent_id: str) -> Optional[AgentState]: - """ - Retrieve the full agent state from the DB. - This gathers data accross multiple tables to provide the full state of an agent, which is passed into the `Agent` object for creation. - """ - - # get data persisted from the DB - agent_state = self.ms.get_agent(agent_id=agent_id) - if agent_state is None: - # agent does not exist - return None - if agent_state.user_id is None: - raise ValueError(f"Agent {agent_id} does not have a user_id") - user = self.user_manager.get_user_by_id(user_id=agent_state.user_id) - - # construct the in-memory, full agent state - this gather data stored in different tables but that needs to be passed to `Agent` - # we also return this data to the user to provide all the state related to an agent - - # get `Memory` object by getting the linked block IDs and fetching the blocks, then putting that into a `Memory` object - # this is the "in memory" representation of the in-context memory - block_ids = self.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id) - blocks = [] - for block_id in block_ids: - block = self.block_manager.get_block_by_id(block_id=block_id, actor=user) - assert block, f"Block with ID {block_id} does not exist" - blocks.append(block) - memory = Memory(blocks=blocks) - - # get `Tool` objects - tools = [self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=user) for tool_name in agent_state.tool_names] - tools = [tool for tool in tools if tool is not None] - - # get `Source` objects - sources = self.list_attached_sources(agent_id=agent_id) - - # get the tags - tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user) - - # return the full agent state - this contains all data needed to recreate the agent - return AgentState(**agent_state.model_dump(), memory=memory, tools=tools, sources=sources, tags=tags) - + # TODO: This is not good! + # TODO: Ideally, this should ALL be handled by the ORM + # TODO: The main blocker here IS the _message updates def update_agent( self, - request: UpdateAgentState, - actor: PydanticUser, + agent_id: str, + request: UpdateAgent, + actor: User, ) -> AgentState: """Update the agents core memory block, return the new state""" - try: - self.user_manager.get_user_by_id(user_id=actor.id) - except Exception: - raise ValueError(f"User user_id={actor.id} does not exist") - - if self.ms.get_agent(agent_id=request.id) is None: - raise ValueError(f"Agent agent_id={request.id} does not exist") - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=request.id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) + + # Update tags + if request.tags is not None: # Allow for empty list + letta_agent.agent_state.tags = request.tags # update the system prompt if request.system: @@ -1025,30 +850,27 @@ class SyncServer(Server): letta_agent.set_message_buffer(message_ids=request.message_ids) # tools - if request.tool_names: + if request.tool_ids: # Replace tools and also re-link # (1) get tools + make sure they exist # Current and target tools as sets of tool names - current_tools = set(letta_agent.agent_state.tool_names) - target_tools = set(request.tool_names) + current_tools = letta_agent.agent_state.tools + current_tool_ids = set([t.id for t in current_tools]) + target_tool_ids = set(request.tool_ids) # Calculate tools to add and remove - tools_to_add = target_tools - current_tools - tools_to_remove = current_tools - target_tools - - # Fetch tool objects for those to add and remove - tools_to_add = [self.tool_manager.get_tool_by_name(tool_name=tool, actor=actor) for tool in tools_to_add] - tools_to_remove = [self.tool_manager.get_tool_by_name(tool_name=tool, actor=actor) for tool in tools_to_remove] + tool_ids_to_add = target_tool_ids - current_tool_ids + tools_ids_to_remove = current_tool_ids - target_tool_ids # update agent tool list - for tool in tools_to_remove: - self.remove_tool_from_agent(agent_id=request.id, tool_id=tool.id, user_id=actor.id) - for tool in tools_to_add: - self.add_tool_to_agent(agent_id=request.id, tool_id=tool.id, user_id=actor.id) + for tool_id in tools_ids_to_remove: + self.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) + for tool_id in tool_ids_to_add: + self.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) # reload agent - letta_agent = self.load_agent(agent_id=request.id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) # configs if request.llm_config: @@ -1062,33 +884,18 @@ class SyncServer(Server): if request.metadata_: letta_agent.agent_state.metadata_ = request.metadata_ - # Manage tag state - if request.tags is not None: - current_tags = set(self.agents_tags_manager.get_tags_for_agent(agent_id=letta_agent.agent_state.id, actor=actor)) - target_tags = set(request.tags) - - tags_to_add = target_tags - current_tags - tags_to_remove = current_tags - target_tags - - for tag in tags_to_add: - self.agents_tags_manager.add_tag_to_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor) - for tag in tags_to_remove: - self.agents_tags_manager.delete_tag_from_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor) - # save the agent - save_agent(letta_agent, self.ms) + save_agent(letta_agent) # TODO: probably reload the agent somehow? return letta_agent.agent_state def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]: """Get tools from an existing agent""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) return letta_agent.agent_state.tools def add_tool_to_agent( @@ -1098,25 +905,20 @@ class SyncServer(Server): user_id: str, ): """Add tools from an existing agent""" - try: - user = self.user_manager.get_user_by_id(user_id=user_id) - except NoResultFound: - raise ValueError(f"User user_id={user_id} does not exist") - - if self.ms.get_agent(agent_id=agent_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) # Get all the tool objects from the request tool_objs = [] - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=user) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor) assert tool_obj, f"Tool with id={tool_id} does not exist" tool_objs.append(tool_obj) for tool in letta_agent.agent_state.tools: - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=user) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor) assert tool_obj, f"Tool with id={tool.id} does not exist" # If it's not the already added tool @@ -1124,13 +926,13 @@ class SyncServer(Server): tool_objs.append(tool_obj) # replace the list of tool names ("ids") inside the agent state - letta_agent.agent_state.tool_names = [tool.name for tool in tool_objs] + letta_agent.agent_state.tools = tool_objs # then attempt to link the tools modules letta_agent.link_tools(tool_objs) # save the agent - save_agent(letta_agent, self.ms) + save_agent(letta_agent) return letta_agent.agent_state def remove_tool_from_agent( @@ -1140,21 +942,16 @@ class SyncServer(Server): user_id: str, ): """Remove tools from an existing agent""" - try: - user = self.user_manager.get_user_by_id(user_id=user_id) - except NoResultFound: - raise ValueError(f"User user_id={user_id} does not exist") - - if self.ms.get_agent(agent_id=agent_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) # Get all the tool_objs tool_objs = [] for tool in letta_agent.agent_state.tools: - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=user) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor) assert tool_obj, f"Tool with id={tool.id} does not exist" # If it's not the tool we want to remove @@ -1162,147 +959,47 @@ class SyncServer(Server): tool_objs.append(tool_obj) # replace the list of tool names ("ids") inside the agent state - letta_agent.agent_state.tool_names = [tool.name for tool in tool_objs] + letta_agent.agent_state.tools = tool_objs # then attempt to link the tools modules letta_agent.link_tools(tool_objs) # save the agent - save_agent(letta_agent, self.ms) + save_agent(letta_agent) return letta_agent.agent_state - def get_agent_state(self, user_id: str, agent_id: str) -> AgentState: - # TODO: duplicate, remove - return self.get_agent(agent_id=agent_id) - - def list_agents(self, user_id: str, tags: Optional[List[str]] = None) -> List[AgentState]: - """List all available agents to a user""" - user = self.user_manager.get_user_by_id(user_id=user_id) - - if tags is None: - agents_states = self.ms.list_agents(user_id=user_id) - agent_ids = [agent.id for agent in agents_states] - else: - agent_ids = [] - for tag in tags: - agent_ids += self.agents_tags_manager.get_agents_by_tag(tag=tag, actor=user) - - return [self.get_agent(agent_id=agent_id) for agent_id in agent_ids] - # convert name->id - def get_agent_id(self, name: str, user_id: str): - agent_state = self.ms.get_agent(agent_name=name, user_id=user_id) - if not agent_state: - return None - return agent_state.id - - def get_source(self, source_id: str, user_id: str) -> Source: - existing_source = self.ms.get_source(source_id=source_id, user_id=user_id) - if not existing_source: - raise ValueError("Source does not exist") - return existing_source - - def get_source_id(self, source_name: str, user_id: str) -> str: - existing_source = self.ms.get_source(source_name=source_name, user_id=user_id) - if not existing_source: - raise ValueError("Source does not exist") - return existing_source.id - - def get_agent_memory(self, agent_id: str) -> Memory: + def get_agent_memory(self, agent_id: str, actor: User) -> Memory: """Return the memory of an agent (core memory)""" - agent = self.load_agent(agent_id=agent_id) + agent = self.load_agent(agent_id=agent_id, actor=actor) return agent.agent_state.memory - def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary: - agent = self.load_agent(agent_id=agent_id) + def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary: + agent = self.load_agent(agent_id=agent_id, actor=actor) return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user)) - def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary: - agent = self.load_agent(agent_id=agent_id) + def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary: + agent = self.load_agent(agent_id=agent_id, actor=actor) return RecallMemorySummary(size=len(agent.message_manager)) - def get_in_context_message_ids(self, agent_id: str) -> List[str]: - """Get the message ids of the in-context messages in the agent's memory""" - # Get the agent object (loaded in memory) - agent = self.load_agent(agent_id=agent_id) - return [m.id for m in agent._messages] - - def get_in_context_messages(self, agent_id: str) -> List[Message]: + def get_in_context_messages(self, agent_id: str, actor: User) -> List[Message]: """Get the in-context messages in the agent's memory""" # Get the agent object (loaded in memory) - agent = self.load_agent(agent_id=agent_id) + agent = self.load_agent(agent_id=agent_id, actor=actor) return agent._messages - def get_agent_message(self, agent_id: str, message_id: str) -> Message: - """Get a single message from the agent's memory""" - # Get the agent object (loaded in memory) - agent = self.load_agent(agent_id=agent_id) - message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user) - return message - - def get_agent_messages( - self, - agent_id: str, - start: int, - count: int, - ) -> Union[List[Message], List[LettaMessage]]: - """Paginated query of all messages in agent message queue""" - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) - - if start < 0 or count < 0: - raise ValueError("Start and count values should be non-negative") - - if start + count < len(letta_agent._messages): # messages can be returned from whats in memory - # Reverse the list to make it in reverse chronological order - reversed_messages = letta_agent._messages[::-1] - # Check if start is within the range of the list - if start >= len(reversed_messages): - raise IndexError("Start index is out of range") - - # Calculate the end index, ensuring it does not exceed the list length - end_index = min(start + count, len(reversed_messages)) - - # Slice the list for pagination - messages = reversed_messages[start:end_index] - - else: - # need to access persistence manager for additional messages - - # get messages using message manager - page = letta_agent.message_manager.list_user_messages_for_agent( - agent_id=agent_id, - actor=self.default_user, - cursor=start, - limit=count, - ) - - messages = page - assert all(isinstance(m, Message) for m in messages) - - ## Convert to json - ## Add a tag indicating in-context or not - # json_messages = [record.to_json() for record in messages] - # in_context_message_ids = [str(m.id) for m in letta_agent._messages] - # for d in json_messages: - # d["in_context"] = True if str(d["id"]) in in_context_message_ids else False - - return messages - - def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[PydanticPassage]: + def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[Passage]: """Paginated query of all messages in agent archival memory""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) # iterate over records records = letta_agent.passage_manager.list_passages( - actor=self.default_user, + actor=actor, agent_id=agent_id, cursor=cursor, limit=limit, @@ -1316,14 +1013,14 @@ class SyncServer(Server): agent_id: str, cursor: Optional[str] = None, limit: Optional[int] = 100, - ) -> List[PydanticPassage]: - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise LettaUserNotFoundError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise LettaAgentNotFoundError(f"Agent agent_id={agent_id} does not exist") + order_by: Optional[str] = "created_at", + reverse: Optional[bool] = False, + ) -> List[Passage]: + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) # iterate over records records = letta_agent.passage_manager.list_passages( @@ -1334,32 +1031,22 @@ class SyncServer(Server): ) return records - def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[PydanticPassage]: - actor = self.user_manager.get_user_by_id(user_id=user_id) - if actor is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") - + def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]: # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) # Insert into archival memory - return self.passage_manager.insert_passage( + passages = self.passage_manager.insert_passage( agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor ) - def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str): - actor = self.user_manager.get_user_by_id(user_id=user_id) - if actor is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") + save_agent(letta_agent) - # TODO: should return a passage + return passages + def delete_archival_memory(self, agent_id: str, memory_id: str, actor: User): # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) # Delete by ID # TODO check if it exists first, and throw error if not @@ -1379,14 +1066,11 @@ class SyncServer(Server): assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, ) -> Union[List[Message], List[LettaMessage]]: - actor = self.user_manager.get_user_by_id(user_id=user_id) - if actor is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) # iterate over records start_date = self.message_manager.get_message_by_id(after, actor=actor).created_at if after else None @@ -1441,95 +1125,19 @@ class SyncServer(Server): return response - def update_agent_core_memory(self, user_id: str, agent_id: str, label: str, value: str) -> Memory: + def update_agent_core_memory(self, agent_id: str, label: str, value: str, actor: User) -> Memory: """Update the value of a block in the agent's memory""" # get the block id - block = self.get_agent_block_by_label(user_id=user_id, agent_id=agent_id, label=label) - block_id = block.id + block = self.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=actor) # update the block - self.block_manager.update_block( - block_id=block_id, block_update=BlockUpdate(value=value), actor=self.user_manager.get_user_by_id(user_id=user_id) - ) + self.block_manager.update_block(block_id=block.id, block_update=BlockUpdate(value=value), actor=actor) # load agent - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) return letta_agent.agent_state.memory - def rename_agent(self, user_id: str, agent_id: str, new_agent_name: str) -> PersistedAgentState: - """Update the name of the agent in the database""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") - - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) - - current_name = letta_agent.agent_state.name - if current_name == new_agent_name: - raise ValueError(f"New name ({new_agent_name}) is the same as the current name") - - try: - letta_agent.agent_state.name = new_agent_name - self.ms.update_agent(agent=letta_agent.agent_state) - except Exception as e: - logger.exception(f"Failed to update agent name with:\n{str(e)}") - raise ValueError(f"Failed to update agent name in database") - - assert isinstance(letta_agent.agent_state.id, str) - return letta_agent.agent_state - - def delete_agent(self, user_id: str, agent_id: str): - """Delete an agent in the database""" - actor = self.user_manager.get_user_by_id(user_id=user_id) - # TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM MODEL - # TODO: EVENTUALLY WE GET AUTO-DELETES WHEN WE SPECIFY RELATIONSHIPS IN THE ORM - self.agents_tags_manager.delete_all_tags_from_agent(agent_id=agent_id, actor=actor) - self.blocks_agents_manager.remove_all_agent_blocks(agent_id=agent_id) - - # Verify that the agent exists and belongs to the org of the user - agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id) - if agent_state is None: - raise ValueError(f"Could not find agent_id={agent_id} under user_id={user_id}") - - # TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM MODEL - messages = self.message_manager.list_messages_for_agent(agent_id=agent_state.id) - for message in messages: - self.message_manager.delete_message_by_id(message.id, actor=actor) - - # TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM - try: - agent_state_user = self.user_manager.get_user_by_id(user_id=agent_state.user_id) - if agent_state_user.organization_id != actor.organization_id: - raise ValueError( - f"Could not authorize agent_id={agent_id} with user_id={user_id} because of differing organizations; agent_id was created in {agent_state_user.organization_id} while user belongs to {actor.organization_id}. How did they get the agent id?" - ) - except NoResultFound: - logger.error(f"Agent with id {agent_state.id} has nonexistent user {agent_state.user_id}") - - # delete all passages associated with this agent - # TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM - passages = self.passage_manager.list_passages(actor=actor, agent_id=agent_state.id) - for passage in passages: - self.passage_manager.delete_passage_by_id(passage.id, actor=actor) - - # First, if the agent is in the in-memory cache we should remove it - # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts - try: - self.active_agents = [d for d in self.active_agents if str(d["agent_id"]) != str(agent_id)] - except Exception as e: - logger.exception(f"Failed to delete agent {agent_id} from cache via ID with:\n{str(e)}") - raise ValueError(f"Failed to delete agent {agent_id} from cache") - - # Next, attempt to delete it from the actual database - try: - self.ms.delete_agent(agent_id=agent_id, per_agent_lock_manager=self.per_agent_lock_manager) - except Exception as e: - logger.exception(f"Failed to delete agent {agent_id} via ID with:\n{str(e)}") - raise ValueError(f"Failed to delete agent {agent_id} in database") - def api_key_to_user(self, api_key: str) -> str: """Decode an API key to a user""" token = self.ms.get_api_key(api_key=api_key) @@ -1557,7 +1165,7 @@ class SyncServer(Server): self.ms.delete_api_key(api_key=api_key) return api_key_obj - def delete_source(self, source_id: str, actor: PydanticUser): + def delete_source(self, source_id: str, actor: User): """Delete a data source""" self.source_manager.delete_source(source_id=source_id, actor=actor) @@ -1566,7 +1174,7 @@ class SyncServer(Server): # TODO: delete data from agent passage stores (?) - def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: PydanticUser) -> Job: + def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job: # update job job = self.job_manager.get_job_by_id(job_id, actor=actor) @@ -1589,13 +1197,14 @@ class SyncServer(Server): self.job_manager.update_job_by_id(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor) # update all agents who have this source attached - agent_ids = self.ms.list_attached_agents(source_id=source_id) - for agent_id in agent_ids: - agent = self.load_agent(agent_id=agent_id) + agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor) + for agent_state in agent_states: + agent_id = agent_state.id + agent = self.load_agent(agent_id=agent_id, actor=actor) curr_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id) - agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, ms=self.ms) + agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager) new_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id) - assert new_passage_size >= curr_passage_size # in case empty files are added + assert new_passage_size >= curr_passage_size # in case empty files are added return job @@ -1626,21 +1235,22 @@ class SyncServer(Server): source_name: Optional[str] = None, ) -> Source: # attach a data source to an agent - user = self.user_manager.get_user_by_id(user_id=user_id) + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) if source_id: - data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=user) + data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor) elif source_name: - data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) + data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor) else: raise ValueError(f"Need to provide at least source_id or source_name to find the source.") assert data_source, f"Data source with id={source_id} or name={source_name} does not exist" # load agent - agent = self.load_agent(agent_id=agent_id) + agent = self.load_agent(agent_id=agent_id, actor=actor) # attach source to agent - agent.attach_source(user=user, source_id=data_source.id, source_manager=self.source_manager, ms=self.ms) + agent.attach_source(user=actor, source_id=data_source.id, source_manager=self.source_manager, agent_manager=self.agent_manager) return data_source @@ -1648,40 +1258,35 @@ class SyncServer(Server): self, user_id: str, agent_id: str, - # source_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None, ) -> Source: - user = self.user_manager.get_user_by_id(user_id=user_id) + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) if source_id: - source = self.source_manager.get_source_by_id(source_id=source_id, actor=user) + source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor) elif source_name: - source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) + source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor) else: raise ValueError(f"Need to provide at least source_id or source_name to find the source.") source_id = source.id + # TODO: This should be done with the ORM? # delete all Passage objects with source_id==source_id from agent's archival memory - agent = self.load_agent(agent_id=agent_id) - agent.passage_manager.delete_passages(actor=user, limit=100, source_id=source_id) + agent = self.load_agent(agent_id=agent_id, actor=actor) + agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id) # delete agent-source mapping - self.ms.detach_source(agent_id=agent_id, source_id=source_id) + self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor) # return back source data return source - def list_attached_sources(self, agent_id: str) -> List[Source]: - # list all attached sources to an agent - source_ids = self.ms.list_attached_source_ids(agent_id) - - return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids] - - def list_data_source_passages(self, user_id: str, source_id: str) -> List[PydanticPassage]: + def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) return [] - def list_all_sources(self, actor: PydanticUser) -> List[Source]: + def list_all_sources(self, actor: User) -> List[Source]: """List all sources (w/ extra metadata) belonging to a user""" sources = self.source_manager.list_sources(actor=actor) @@ -1699,15 +1304,9 @@ class SyncServer(Server): # num_documents = document_conn.size({"data_source": source.name}) num_documents = 0 - agent_ids = self.ms.list_attached_agents(source_id=source.id) + agents = self.source_manager.list_attached_agents(source_id=source.id, actor=actor) # add the agent name information - attached_agents = [ - { - "id": str(a_id), - "name": self.ms.get_agent(user_id=actor.id, agent_id=a_id).name, - } - for a_id in agent_ids - ] + attached_agents = [{"id": agent.id, "name": agent.name} for agent in agents] # Overwrite metadata field, should be empty anyways source.metadata_ = dict( @@ -1720,7 +1319,7 @@ class SyncServer(Server): return sources_with_metadata - def add_default_external_tools(self, actor: PydanticUser) -> bool: + def add_default_external_tools(self, actor: User) -> bool: """Add default langchain tools. Return true if successful, false otherwise.""" success = True tool_creates = ToolCreate.load_default_langchain_tools() @@ -1736,57 +1335,37 @@ class SyncServer(Server): return success - def get_agent_message(self, agent_id: str, message_id: str) -> Optional[Message]: - """Get a single message from the agent's memory""" - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) - message = letta_agent.message_manager.get_message_by_id(id=message_id) - save_agent(letta_agent, self.ms) - return message - - def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate) -> Message: + def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate, actor: User) -> Message: """Update the details of a message associated with an agent""" # Get the current message - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) response = letta_agent.update_message(message_id=message_id, request=request) - save_agent(letta_agent, self.ms) + save_agent(letta_agent) return response - def rewrite_agent_message(self, agent_id: str, new_text: str) -> Message: + def rewrite_agent_message(self, agent_id: str, new_text: str, actor: User) -> Message: # Get the current message - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) response = letta_agent.rewrite_message(new_text=new_text) - save_agent(letta_agent, self.ms) + save_agent(letta_agent) return response - def rethink_agent_message(self, agent_id: str, new_thought: str) -> Message: - + def rethink_agent_message(self, agent_id: str, new_thought: str, actor: User) -> Message: # Get the current message - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) response = letta_agent.rethink_message(new_thought=new_thought) - save_agent(letta_agent, self.ms) + save_agent(letta_agent) return response - def retry_agent_message(self, agent_id: str) -> List[Message]: - + def retry_agent_message(self, agent_id: str, actor: User) -> List[Message]: # Get the current message - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) response = letta_agent.retry_message() - save_agent(letta_agent, self.ms) + save_agent(letta_agent) return response - def get_user_or_default(self, user_id: Optional[str]) -> PydanticUser: - """Get the user object for user_id if it exists, otherwise return the default user object""" - if user_id is None: - user_id = self.user_manager.DEFAULT_USER_ID - - try: - return self.user_manager.get_user_by_id(user_id=user_id) - except NoResultFound: - raise HTTPException(status_code=404, detail=f"User with id {user_id} not found") - def get_organization_or_default(self, org_id: Optional[str]) -> Organization: """Get the organization object for org_id if it exists, otherwise return the default organization object""" if org_id is None: @@ -1829,100 +1408,13 @@ class SyncServer(Server): user_id: str, agent_id: str, ) -> ContextWindowOverview: + # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user + actor = self.user_manager.get_user_or_default(user_id=user_id) + # Get the current message - letta_agent = self.load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) return letta_agent.get_context_window() - def link_block_to_agent_memory(self, user_id: str, agent_id: str, block_id: str) -> Memory: - """Link a block to an agent's memory""" - block = self.block_manager.get_block_by_id(block_id=block_id, actor=self.user_manager.get_user_by_id(user_id=user_id)) - if block is None: - raise ValueError(f"Block with id {block_id} not found") - self.blocks_agents_manager.add_block_to_agent(agent_id, block_id, block_label=block.label) - - # get agent memory - memory = self.get_agent(agent_id=agent_id).memory - return memory - - def unlink_block_from_agent_memory(self, user_id: str, agent_id: str, block_label: str, delete_if_no_ref: bool = True) -> Memory: - """Unlink a block from an agent's memory. If the block is not linked to any agent, delete it.""" - self.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=agent_id, block_label=block_label) - - # get agent memory - memory = self.get_agent(agent_id=agent_id).memory - return memory - - def update_agent_memory_limit(self, user_id: str, agent_id: str, block_label: str, limit: int) -> Memory: - """Update the limit of a block in an agent's memory""" - block = self.get_agent_block_by_label(user_id=user_id, agent_id=agent_id, label=block_label) - self.block_manager.update_block( - block_id=block.id, block_update=BlockUpdate(limit=limit), actor=self.user_manager.get_user_by_id(user_id=user_id) - ) - # get agent memory - memory = self.get_agent(agent_id=agent_id).memory - return memory - - def upate_block(self, user_id: str, block_id: str, block_update: BlockUpdate) -> Block: - """Update a block""" - return self.block_manager.update_block( - block_id=block_id, block_update=block_update, actor=self.user_manager.get_user_by_id(user_id=user_id) - ) - - def get_agent_block_by_label(self, user_id: str, agent_id: str, label: str) -> Block: - """Get a block by label""" - # TODO: implement at ORM? - for block_id in self.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id): - block = self.block_manager.get_block_by_id(block_id=block_id, actor=self.user_manager.get_user_by_id(user_id=user_id)) - if block.label == label: - return block - return None - - # def run_tool(self, tool_id: str, tool_args: str, user_id: str) -> FunctionReturn: - # """Run a tool using the sandbox and return the result""" - - # try: - # tool_args_dict = json.loads(tool_args) - # except json.JSONDecodeError: - # raise ValueError("Invalid JSON string for tool_args") - - # # Get the tool by ID - # user = self.user_manager.get_user_by_id(user_id=user_id) - # tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=user) - # if tool.name is None: - # raise ValueError(f"Tool with id {tool_id} does not have a name") - - # # TODO eventually allow using agent state in tools - # agent_state = None - - # try: - # sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id).run(agent_state=agent_state) - # if sandbox_run_result is None: - # raise ValueError(f"Tool with id {tool_id} returned execution with None") - # function_response = str(sandbox_run_result.func_return) - - # return FunctionReturn( - # id="null", - # function_call_id="null", - # date=get_utc_time(), - # status="success", - # function_return=function_response, - # ) - # except Exception as e: - # # same as agent.py - # from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT - - # error_msg = f"Error executing tool {tool.name}: {e}" - # if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT: - # error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT] - - # return FunctionReturn( - # id="null", - # function_call_id="null", - # date=get_utc_time(), - # status="error", - # function_return=error_msg, - # ) - def run_tool_from_source( self, user_id: str, @@ -1994,7 +1486,6 @@ class SyncServer(Server): stderr=[traceback.format_exc()], ) - def get_error_msg_for_func_return(self, tool_name, exception_message): # same as agent.py from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT @@ -2004,7 +1495,6 @@ class SyncServer(Server): error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT] return error_msg - # Composio wrappers def get_composio_client(self, api_key: Optional[str] = None): if api_key: diff --git a/letta/server/ws_api/server.py b/letta/server/ws_api/server.py index f2ec4f99..975bd0d2 100644 --- a/letta/server/ws_api/server.py +++ b/letta/server/ws_api/server.py @@ -19,11 +19,6 @@ class WebSocketServer: self.server = SyncServer(default_interface=self.interface) def shutdown_server(self): - try: - self.server.save_agents() - print(f"Saved agents") - except Exception as e: - print(f"Saving agents failed with: {e}") try: self.interface.close() print(f"Closed the WS interface") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py new file mode 100644 index 00000000..093831aa --- /dev/null +++ b/letta/services/agent_manager.py @@ -0,0 +1,405 @@ +from typing import Dict, List, Optional + +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS +from letta.orm import Agent as AgentModel +from letta.orm import Block as BlockModel +from letta.orm import Source as SourceModel +from letta.orm import Tool as ToolModel +from letta.orm.errors import NoResultFound +from letta.schemas.agent import AgentState as PydanticAgentState +from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent +from letta.schemas.block import Block as PydanticBlock +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig +from letta.schemas.source import Source as PydanticSource +from letta.schemas.tool_rule import ToolRule as PydanticToolRule +from letta.schemas.user import User as PydanticUser +from letta.services.block_manager import BlockManager +from letta.services.helpers.agent_manager_helper import ( + _process_relationship, + _process_tags, + derive_system_message, +) +from letta.services.passage_manager import PassageManager +from letta.services.source_manager import SourceManager +from letta.services.tool_manager import ToolManager +from letta.utils import enforce_types + + +# Agent Manager Class +class AgentManager: + """Manager class to handle business logic related to Agents.""" + + def __init__(self): + from letta.server.server import db_context + + self.session_maker = db_context + self.block_manager = BlockManager() + self.tool_manager = ToolManager() + self.source_manager = SourceManager() + + # ====================================================================================================================== + # Basic CRUD operations + # ====================================================================================================================== + @enforce_types + def create_agent( + self, + agent_create: CreateAgent, + actor: PydanticUser, + ) -> PydanticAgentState: + system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system) + + # create blocks (note: cannot be linked into the agent_id is created) + block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original + for create_block in agent_create.memory_blocks: + block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump()), actor=actor) + block_ids.append(block.id) + + # TODO: Remove this block once we deprecate the legacy `tools` field + # create passed in `tools` + tool_names = [] + if agent_create.include_base_tools: + tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS) + if agent_create.tools: + tool_names.extend(agent_create.tools) + + tool_ids = agent_create.tool_ids or [] + for tool_name in tool_names: + tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) + if tool: + tool_ids.append(tool.id) + # Remove duplicates + tool_ids = list(set(tool_ids)) + + return self._create_agent( + name=agent_create.name, + system=system, + agent_type=agent_create.agent_type, + llm_config=agent_create.llm_config, + embedding_config=agent_create.embedding_config, + block_ids=block_ids, + tool_ids=tool_ids, + source_ids=agent_create.source_ids or [], + tags=agent_create.tags or [], + description=agent_create.description, + metadata_=agent_create.metadata_, + tool_rules=agent_create.tool_rules, + actor=actor, + ) + + @enforce_types + def _create_agent( + self, + actor: PydanticUser, + name: str, + system: str, + agent_type: AgentType, + llm_config: LLMConfig, + embedding_config: EmbeddingConfig, + block_ids: List[str], + tool_ids: List[str], + source_ids: List[str], + tags: List[str], + description: Optional[str] = None, + metadata_: Optional[Dict] = None, + tool_rules: Optional[List[PydanticToolRule]] = None, + ) -> PydanticAgentState: + """Create a new agent.""" + with self.session_maker() as session: + # Prepare the agent data + data = { + "name": name, + "system": system, + "agent_type": agent_type, + "llm_config": llm_config, + "embedding_config": embedding_config, + "organization_id": actor.organization_id, + "description": description, + "metadata_": metadata_, + "tool_rules": tool_rules, + } + + # Create the new agent using SqlalchemyBase.create + new_agent = AgentModel(**data) + _process_relationship(session, new_agent, "tools", ToolModel, tool_ids, replace=True) + _process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True) + _process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True) + _process_tags(new_agent, tags, replace=True) + new_agent.create(session, actor=actor) + + # Convert to PydanticAgentState and return + return new_agent.to_pydantic() + + @enforce_types + def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState: + """ + Update an existing agent. + + Args: + agent_id: The ID of the agent to update. + agent_update: UpdateAgent object containing the updated fields. + actor: User performing the action. + + Returns: + PydanticAgentState: The updated agent as a Pydantic model. + """ + with self.session_maker() as session: + # Retrieve the existing agent + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + + # Update scalar fields directly + scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata_"} + for field in scalar_fields: + value = getattr(agent_update, field, None) + if value is not None: + setattr(agent, field, value) + + # Update relationships using _process_relationship and _process_tags + if agent_update.tool_ids is not None: + _process_relationship(session, agent, "tools", ToolModel, agent_update.tool_ids, replace=True) + if agent_update.source_ids is not None: + _process_relationship(session, agent, "sources", SourceModel, agent_update.source_ids, replace=True) + if agent_update.block_ids is not None: + _process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True) + if agent_update.tags is not None: + _process_tags(agent, agent_update.tags, replace=True) + + # Commit and refresh the agent + agent.update(session, actor=actor) + + # Convert to PydanticAgentState and return + return agent.to_pydantic() + + @enforce_types + def list_agents( + self, + actor: PydanticUser, + tags: Optional[List[str]] = None, + match_all_tags: bool = False, + cursor: Optional[str] = None, + limit: Optional[int] = 50, + **kwargs, + ) -> List[PydanticAgentState]: + """ + List agents that have the specified tags. + """ + with self.session_maker() as session: + agents = AgentModel.list( + db_session=session, + tags=tags, + match_all_tags=match_all_tags, + cursor=cursor, + limit=limit, + organization_id=actor.organization_id if actor else None, + **kwargs, + ) + + return [agent.to_pydantic() for agent in agents] + + @enforce_types + def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: + """Fetch an agent by its ID.""" + with self.session_maker() as session: + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + return agent.to_pydantic() + + @enforce_types + def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState: + """Fetch an agent by its ID.""" + with self.session_maker() as session: + agent = AgentModel.read(db_session=session, name=agent_name, actor=actor) + return agent.to_pydantic() + + @enforce_types + def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: + """ + Deletes an agent and its associated relationships. + Ensures proper permission checks and cascades where applicable. + + Args: + agent_id: ID of the agent to be deleted. + actor: User performing the action. + + Returns: + PydanticAgentState: The deleted agent state + """ + with self.session_maker() as session: + # Retrieve the agent + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + + # TODO: @mindy delete this piece when we have a proper passages/sources implementation + # TODO: This is done very hacky on purpose + # TODO: 1000 limit is also wack + passage_manager = PassageManager() + passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000) + + agent_state = agent.to_pydantic() + agent.hard_delete(session) + return agent_state + + # ====================================================================================================================== + # Source Management + # ====================================================================================================================== + @enforce_types + def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None: + """ + Attaches a source to an agent. + + Args: + agent_id: ID of the agent to attach the source to + source_id: ID of the source to attach + actor: User performing the action + + Raises: + ValueError: If either agent or source doesn't exist + IntegrityError: If the source is already attached to the agent + """ + with self.session_maker() as session: + # Verify both agent and source exist and user has permission to access them + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + + # The _process_relationship helper already handles duplicate checking via unique constraint + _process_relationship( + session=session, + agent=agent, + relationship_name="sources", + model_class=SourceModel, + item_ids=[source_id], + allow_partial=False, + replace=False, # Extend existing sources rather than replace + ) + + # Commit the changes + agent.update(session, actor=actor) + + @enforce_types + def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: + """ + Lists all sources attached to an agent. + + Args: + agent_id: ID of the agent to list sources for + actor: User performing the action + + Returns: + List[str]: List of source IDs attached to the agent + """ + with self.session_maker() as session: + # Verify agent exists and user has permission to access it + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + + # Use the lazy-loaded relationship to get sources + return [source.to_pydantic() for source in agent.sources] + + @enforce_types + def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None: + """ + Detaches a source from an agent. + + Args: + agent_id: ID of the agent to detach the source from + source_id: ID of the source to detach + actor: User performing the action + """ + with self.session_maker() as session: + # Verify agent exists and user has permission to access it + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + + # Remove the source from the relationship + agent.sources = [s for s in agent.sources if s.id != source_id] + + # Commit the changes + agent.update(session, actor=actor) + + # ====================================================================================================================== + # Block management + # ====================================================================================================================== + @enforce_types + def get_block_with_label( + self, + agent_id: str, + block_label: str, + actor: PydanticUser, + ) -> PydanticBlock: + """Gets a block attached to an agent by its label.""" + with self.session_maker() as session: + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + for block in agent.core_memory: + if block.label == block_label: + return block.to_pydantic() + raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") + + @enforce_types + def update_block_with_label( + self, + agent_id: str, + block_label: str, + new_block_id: str, + actor: PydanticUser, + ) -> PydanticAgentState: + """Updates which block is assigned to a specific label for an agent.""" + with self.session_maker() as session: + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor) + + if new_block.label != block_label: + raise ValueError(f"New block label '{new_block.label}' doesn't match required label '{block_label}'") + + # Remove old block with this label if it exists + agent.core_memory = [b for b in agent.core_memory if b.label != block_label] + + # Add new block + agent.core_memory.append(new_block) + agent.update(session, actor=actor) + return agent.to_pydantic() + + @enforce_types + def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: + """Attaches a block to an agent.""" + with self.session_maker() as session: + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) + + agent.core_memory.append(block) + agent.update(session, actor=actor) + return agent.to_pydantic() + + @enforce_types + def detach_block( + self, + agent_id: str, + block_id: str, + actor: PydanticUser, + ) -> PydanticAgentState: + """Detaches a block from an agent.""" + with self.session_maker() as session: + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + original_length = len(agent.core_memory) + + agent.core_memory = [b for b in agent.core_memory if b.id != block_id] + + if len(agent.core_memory) == original_length: + raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'") + + agent.update(session, actor=actor) + return agent.to_pydantic() + + @enforce_types + def detach_block_with_label( + self, + agent_id: str, + block_label: str, + actor: PydanticUser, + ) -> PydanticAgentState: + """Detaches a block with the specified label from an agent.""" + with self.session_maker() as session: + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + original_length = len(agent.core_memory) + + agent.core_memory = [b for b in agent.core_memory if b.label != block_label] + + if len(agent.core_memory) == original_length: + raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}' with actor id: '{actor.id}'") + + agent.update(session, actor=actor) + return agent.to_pydantic() diff --git a/letta/services/agents_tags_manager.py b/letta/services/agents_tags_manager.py deleted file mode 100644 index f84ea11b..00000000 --- a/letta/services/agents_tags_manager.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import List - -from letta.orm.agents_tags import AgentsTags as AgentsTagsModel -from letta.orm.errors import NoResultFound -from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags -from letta.schemas.user import User as PydanticUser -from letta.utils import enforce_types - - -class AgentsTagsManager: - """Manager class to handle business logic related to Tags.""" - - def __init__(self): - from letta.server.server import db_context - - self.session_maker = db_context - - @enforce_types - def add_tag_to_agent(self, agent_id: str, tag: str, actor: PydanticUser) -> PydanticAgentsTags: - """Add a tag to an agent.""" - with self.session_maker() as session: - # Check if the tag already exists for this agent - try: - agents_tags_model = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor) - return agents_tags_model.to_pydantic() - except NoResultFound: - agents_tags = PydanticAgentsTags(agent_id=agent_id, tag=tag).model_dump(exclude_none=True) - new_tag = AgentsTagsModel(**agents_tags, organization_id=actor.organization_id) - new_tag.create(session, actor=actor) - return new_tag.to_pydantic() - - @enforce_types - def delete_all_tags_from_agent(self, agent_id: str, actor: PydanticUser): - """Delete a tag from an agent. This is a permanent hard delete.""" - tags = self.get_tags_for_agent(agent_id=agent_id, actor=actor) - for tag in tags: - self.delete_tag_from_agent(agent_id=agent_id, tag=tag, actor=actor) - - @enforce_types - def delete_tag_from_agent(self, agent_id: str, tag: str, actor: PydanticUser): - """Delete a tag from an agent.""" - with self.session_maker() as session: - try: - # Retrieve and delete the tag association - tag_association = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor) - tag_association.hard_delete(session, actor=actor) - except NoResultFound: - raise ValueError(f"Tag '{tag}' not found for agent '{agent_id}'.") - - @enforce_types - def get_agents_by_tag(self, tag: str, actor: PydanticUser) -> List[str]: - """Retrieve all agent IDs associated with a specific tag.""" - with self.session_maker() as session: - # Query for all agents with the given tag - agents_with_tag = AgentsTagsModel.list(db_session=session, tag=tag, organization_id=actor.organization_id) - return [record.agent_id for record in agents_with_tag] - - @enforce_types - def get_tags_for_agent(self, agent_id: str, actor: PydanticUser) -> List[str]: - """Retrieve all tags associated with a specific agent.""" - with self.session_maker() as session: - # Query for all tags associated with the given agent - tags_for_agent = AgentsTagsModel.list(db_session=session, agent_id=agent_id, organization_id=actor.organization_id) - return [record.tag for record in tags_for_agent] diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 65f6c79e..77eb5e7e 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -7,7 +7,6 @@ from letta.schemas.block import Block from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate, Human, Persona from letta.schemas.user import User as PydanticUser -from letta.services.blocks_agents_manager import BlocksAgentsManager from letta.utils import enforce_types, list_human_files, list_persona_files @@ -37,33 +36,17 @@ class BlockManager: @enforce_types def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" - # TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents - blocks_agents_manager = BlocksAgentsManager() - agent_ids = [] - if block_update.label: - agent_ids = blocks_agents_manager.list_agent_ids_with_block(block_id=block_id) - for agent_id in agent_ids: - blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id) + # Safety check for block with self.session_maker() as session: - # Update block block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) update_data = block_update.model_dump(exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): setattr(block, key, value) - try: - block.to_pydantic() - except Exception as e: - # invalid pydantic model - raise ValueError(f"Failed to create pydantic model: {e}") + block.update(db_session=session, actor=actor) - - # TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents - if block_update.label: - for agent_id in agent_ids: - blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block_update.label) - - return block.to_pydantic() + return block.to_pydantic() @enforce_types def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock: @@ -111,6 +94,15 @@ class BlockManager: except NoResultFound: return None + @enforce_types + def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: + # TODO: We can do this much more efficiently by listing, instead of executing individual queries per block_id + blocks = [] + for block_id in block_ids: + block = self.get_block_by_id(block_id, actor=actor) + blocks.append(block) + return blocks + @enforce_types def add_default_blocks(self, actor: PydanticUser): for persona_file in list_persona_files(): diff --git a/letta/services/blocks_agents_manager.py b/letta/services/blocks_agents_manager.py deleted file mode 100644 index 121db586..00000000 --- a/letta/services/blocks_agents_manager.py +++ /dev/null @@ -1,106 +0,0 @@ -import warnings -from typing import List - -from letta.orm.blocks_agents import BlocksAgents as BlocksAgentsModel -from letta.orm.errors import NoResultFound -from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents -from letta.utils import enforce_types - - -# TODO: DELETE THIS ASAP -# TODO: So we have a patch where we manually specify CRUD operations -# TODO: This is because Agent is NOT migrated to the ORM yet -# TODO: Once we migrate Agent to the ORM, we should deprecate any agents relationship table managers -class BlocksAgentsManager: - """Manager class to handle business logic related to Blocks and Agents.""" - - def __init__(self): - from letta.server.server import db_context - - self.session_maker = db_context - - @enforce_types - def add_block_to_agent(self, agent_id: str, block_id: str, block_label: str) -> PydanticBlocksAgents: - """Add a block to an agent. If the label already exists on that agent, this will error.""" - with self.session_maker() as session: - try: - # Check if the block-label combination already exists for this agent - blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label) - warnings.warn(f"Block label '{block_label}' already exists for agent '{agent_id}'.") - except NoResultFound: - blocks_agents_record = PydanticBlocksAgents(agent_id=agent_id, block_id=block_id, block_label=block_label) - blocks_agents_record = BlocksAgentsModel(**blocks_agents_record.model_dump(exclude_none=True)) - blocks_agents_record.create(session) - - return blocks_agents_record.to_pydantic() - - @enforce_types - def remove_block_with_label_from_agent(self, agent_id: str, block_label: str) -> PydanticBlocksAgents: - """Remove a block with a label from an agent.""" - with self.session_maker() as session: - try: - # Find and delete the block-label association for the agent - blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label) - blocks_agents_record.hard_delete(session) - return blocks_agents_record.to_pydantic() - except NoResultFound: - raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.") - - @enforce_types - def remove_block_with_id_from_agent(self, agent_id: str, block_id: str) -> PydanticBlocksAgents: - """Remove a block with a label from an agent.""" - with self.session_maker() as session: - try: - # Find and delete the block-label association for the agent - blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_id=block_id) - blocks_agents_record.hard_delete(session) - return blocks_agents_record.to_pydantic() - except NoResultFound: - raise ValueError(f"Block id '{block_id}' not found for agent '{agent_id}'.") - - @enforce_types - def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_id: str) -> PydanticBlocksAgents: - """Update the block ID for a specific block label for an agent.""" - with self.session_maker() as session: - try: - blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label) - blocks_agents_record.block_id = new_block_id - return blocks_agents_record.to_pydantic() - except NoResultFound: - raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.") - - @enforce_types - def list_block_ids_for_agent(self, agent_id: str) -> List[str]: - """List all block ids associated with a specific agent.""" - with self.session_maker() as session: - blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id) - return [record.block_id for record in blocks_agents_record] - - @enforce_types - def list_block_labels_for_agent(self, agent_id: str) -> List[str]: - """List all block labels associated with a specific agent.""" - with self.session_maker() as session: - blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id) - return [record.block_label for record in blocks_agents_record] - - @enforce_types - def list_agent_ids_with_block(self, block_id: str) -> List[str]: - """List all agents associated with a specific block.""" - with self.session_maker() as session: - blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id) - return [record.agent_id for record in blocks_agents_record] - - @enforce_types - def get_block_id_for_label(self, agent_id: str, block_label: str) -> str: - """Get the block ID for a specific block label for an agent.""" - with self.session_maker() as session: - try: - blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label) - return blocks_agents_record.block_id - except NoResultFound: - raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.") - - @enforce_types - def remove_all_agent_blocks(self, agent_id: str): - for block_id in self.list_block_ids_for_agent(agent_id): - self.remove_block_with_id_from_agent(agent_id, block_id) diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py new file mode 100644 index 00000000..95ad26be --- /dev/null +++ b/letta/services/helpers/agent_manager_helper.py @@ -0,0 +1,90 @@ +from typing import List, Optional + +from letta.orm.agent import Agent as AgentModel +from letta.orm.agents_tags import AgentsTags +from letta.orm.errors import NoResultFound +from letta.prompts import gpt_system +from letta.schemas.agent import AgentType + + +# Static methods +def _process_relationship( + session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True +): + """ + Generalized function to handle relationships like tools, sources, and blocks using item IDs. + + Args: + session: The database session. + agent: The AgentModel instance. + relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources'). + model_class: The ORM class corresponding to the related items. + item_ids: List of IDs to set or update. + allow_partial: If True, allows missing items without raising errors. + replace: If True, replaces the entire relationship; otherwise, extends it. + + Raises: + ValueError: If `allow_partial` is False and some IDs are missing. + """ + current_relationship = getattr(agent, relationship_name, []) + if not item_ids: + if replace: + setattr(agent, relationship_name, []) + return + + # Retrieve models for the provided IDs + found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all() + + # Validate all items are found if allow_partial is False + if not allow_partial and len(found_items) != len(item_ids): + missing = set(item_ids) - {item.id for item in found_items} + raise NoResultFound(f"Items not found in {relationship_name}: {missing}") + + if replace: + # Replace the relationship + setattr(agent, relationship_name, found_items) + else: + # Extend the relationship (only add new items) + current_ids = {item.id for item in current_relationship} + new_items = [item for item in found_items if item.id not in current_ids] + current_relationship.extend(new_items) + + +def _process_tags(agent: AgentModel, tags: List[str], replace=True): + """ + Handles tags for an agent. + + Args: + agent: The AgentModel instance. + tags: List of tags to set or update. + replace: If True, replaces all tags; otherwise, extends them. + """ + if not tags: + if replace: + agent.tags = [] + return + + # Ensure tags are unique and prepare for replacement/extension + new_tags = {AgentsTags(agent_id=agent.id, tag=tag) for tag in set(tags)} + if replace: + agent.tags = list(new_tags) + else: + existing_tags = {t.tag for t in agent.tags} + agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags]) + + +def derive_system_message(agent_type: AgentType, system: Optional[str] = None): + if system is None: + # TODO: don't hardcode + if agent_type == AgentType.memgpt_agent: + system = gpt_system.get_system_text("memgpt_chat") + elif agent_type == AgentType.o1_agent: + system = gpt_system.get_system_text("memgpt_modified_o1") + elif agent_type == AgentType.offline_memory_agent: + system = gpt_system.get_system_text("memgpt_offline_memory") + elif agent_type == AgentType.chat_only_agent: + system = gpt_system.get_system_text("memgpt_convo_only") + else: + raise ValueError(f"Invalid agent type: {agent_type}") + + return system diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index ef93b732..100a4433 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -1,25 +1,25 @@ -from typing import List, Optional, Dict, Tuple -from letta.constants import MAX_EMBEDDING_DIM from datetime import datetime +from typing import List, Optional + import numpy as np -from letta.orm.errors import NoResultFound -from letta.utils import enforce_types - +from letta.constants import MAX_EMBEDDING_DIM from letta.embeddings import embedding_model, parse_and_chunk_text -from letta.schemas.embedding_config import EmbeddingConfig - +from letta.orm.errors import NoResultFound from letta.orm.passage import Passage as PassageModel -from letta.orm.sqlalchemy_base import AccessType from letta.schemas.agent import AgentState +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.user import User as PydanticUser +from letta.utils import enforce_types + class PassageManager: """Manager class to handle business logic related to Passages.""" def __init__(self): from letta.server.server import db_context + self.session_maker = db_context @enforce_types @@ -43,20 +43,20 @@ class PassageManager: return [self.create_passage(p, actor) for p in passages] @enforce_types - def insert_passage(self, + def insert_passage( + self, agent_state: AgentState, agent_id: str, - text: str, - actor: PydanticUser, - return_ids: bool = False + text: str, + actor: PydanticUser, ) -> List[PydanticPassage]: - """ Insert passage(s) into archival memory """ + """Insert passage(s) into archival memory""" embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size embed_model = embedding_model(agent_state.embedding_config) passages = [] - + try: # breakup string into passages for text in parse_and_chunk_text(text, embedding_chunk_size): @@ -75,12 +75,12 @@ class PassageManager: agent_id=agent_id, text=text, embedding=embedding, - embedding_config=agent_state.embedding_config + embedding_config=agent_state.embedding_config, ), - actor=actor + actor=actor, ) passages.append(passage) - + return passages except Exception as e: @@ -125,20 +125,21 @@ class PassageManager: raise ValueError(f"Passage with id {passage_id} not found.") @enforce_types - def list_passages(self, - actor : PydanticUser, - agent_id : Optional[str] = None, - file_id : Optional[str] = None, - cursor : Optional[str] = None, - limit : Optional[int] = 50, - query_text : Optional[str] = None, - start_date : Optional[datetime] = None, - end_date : Optional[datetime] = None, - ascending : bool = True, - source_id : Optional[str] = None, - embed_query : bool = False, - embedding_config: Optional[EmbeddingConfig] = None - ) -> List[PydanticPassage]: + def list_passages( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + cursor: Optional[str] = None, + limit: Optional[int] = 50, + query_text: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ascending: bool = True, + source_id: Optional[str] = None, + embed_query: bool = False, + embedding_config: Optional[EmbeddingConfig] = None, + ) -> List[PydanticPassage]: """List passages with pagination.""" with self.session_maker() as session: filters = {"organization_id": actor.organization_id} @@ -148,7 +149,7 @@ class PassageManager: filters["file_id"] = file_id if source_id: filters["source_id"] = source_id - + embedded_text = None if embed_query: assert embedding_config is not None @@ -161,7 +162,7 @@ class PassageManager: embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() results = PassageModel.list( - db_session=session, + db_session=session, cursor=cursor, start_date=start_date, end_date=end_date, @@ -169,17 +170,12 @@ class PassageManager: ascending=ascending, query_text=query_text if not embedded_text else None, query_embedding=embedded_text, - **filters + **filters, ) return [p.to_pydantic() for p in results] - + @enforce_types - def size( - self, - actor : PydanticUser, - agent_id : Optional[str] = None, - **kwargs - ) -> int: + def size(self, actor: PydanticUser, agent_id: Optional[str] = None, **kwargs) -> int: """Get the total count of messages with optional filters. Args: @@ -189,28 +185,32 @@ class PassageManager: with self.session_maker() as session: return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs) - def delete_passages(self, - actor: PydanticUser, - agent_id: Optional[str] = None, - file_id: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - limit: Optional[int] = 50, - cursor: Optional[str] = None, - query_text: Optional[str] = None, - source_id: Optional[str] = None - ) -> bool: - + def delete_passages( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + limit: Optional[int] = 50, + cursor: Optional[str] = None, + query_text: Optional[str] = None, + source_id: Optional[str] = None, + ) -> bool: + passages = self.list_passages( - actor=actor, - agent_id=agent_id, - file_id=file_id, - cursor=cursor, + actor=actor, + agent_id=agent_id, + file_id=file_id, + cursor=cursor, limit=limit, - start_date=start_date, - end_date=end_date, - query_text=query_text, - source_id=source_id) - + start_date=start_date, + end_date=end_date, + query_text=query_text, + source_id=source_id, + ) + + # TODO: This is very inefficient + # TODO: We should have a base `delete_all_matching_filters`-esque function for passage in passages: self.delete_passage_by_id(passage_id=passage.id, actor=actor) diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index a6745cec..a5804347 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -3,6 +3,7 @@ from typing import List, Optional from letta.orm.errors import NoResultFound from letta.orm.file import FileMetadata as FileMetadataModel from letta.orm.source import Source as SourceModel +from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate @@ -60,7 +61,7 @@ class SourceManager: """Delete a source by its ID.""" with self.session_maker() as session: source = SourceModel.read(db_session=session, identifier=source_id) - source.delete(db_session=session, actor=actor) + source.hard_delete(db_session=session, actor=actor) return source.to_pydantic() @enforce_types @@ -76,6 +77,26 @@ class SourceManager: ) return [source.to_pydantic() for source in sources] + @enforce_types + def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]: + """ + Lists all agents that have the specified source attached. + + Args: + source_id: ID of the source to find attached agents for + actor: User performing the action (optional for now, following existing pattern) + + Returns: + List[PydanticAgentState]: List of agents that have this source attached + """ + with self.session_maker() as session: + # Verify source exists and user has permission to access it + source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + + # The agents relationship is already loaded due to lazy="selectin" in the Source model + # and will be properly filtered by organization_id due to the OrganizationMixin + return [agent.to_pydantic() for agent in source.agents] + # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]: diff --git a/letta/services/tools_agents_manager.py b/letta/services/tools_agents_manager.py deleted file mode 100644 index 35b24e5a..00000000 --- a/letta/services/tools_agents_manager.py +++ /dev/null @@ -1,94 +0,0 @@ -import warnings -from typing import List, Optional - -from sqlalchemy import select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session - -from letta.orm.errors import NoResultFound -from letta.orm.organization import Organization -from letta.orm.tool import Tool -from letta.orm.tools_agents import ToolsAgents as ToolsAgentsModel -from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents - -class ToolsAgentsManager: - """Manages the relationship between tools and agents.""" - - def __init__(self): - from letta.server.server import db_context - self.session_maker = db_context - - def add_tool_to_agent(self, agent_id: str, tool_id: str, tool_name: str) -> PydanticToolsAgents: - """Add a tool to an agent. - - When a tool is added to an agent, it will be added to all agents in the same organization. - """ - with self.session_maker() as session: - try: - # Check if the tool-agent combination already exists for this agent - tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name) - warnings.warn(f"Tool name '{tool_name}' already exists for agent '{agent_id}'.") - except NoResultFound: - tools_agents_record = PydanticToolsAgents(agent_id=agent_id, tool_id=tool_id, tool_name=tool_name) - tools_agents_record = ToolsAgentsModel(**tools_agents_record.model_dump(exclude_none=True)) - tools_agents_record.create(session) - - return tools_agents_record.to_pydantic() - - def remove_tool_with_name_from_agent(self, agent_id: str, tool_name: str) -> None: - """Remove a tool from an agent by its name. - - When a tool is removed from an agent, it will be removed from all agents in the same organization. - """ - with self.session_maker() as session: - try: - # Find and delete the tool-agent association for the agent - tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name) - tools_agents_record.hard_delete(session) - return tools_agents_record.to_pydantic() - except NoResultFound: - raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.") - - def remove_tool_with_id_from_agent(self, agent_id: str, tool_id: str) -> PydanticToolsAgents: - """Remove a tool with an ID from an agent.""" - with self.session_maker() as session: - try: - tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_id=tool_id) - tools_agents_record.hard_delete(session) - return tools_agents_record.to_pydantic() - except NoResultFound: - raise ValueError(f"Tool ID '{tool_id}' not found for agent '{agent_id}'.") - - def list_tool_ids_for_agent(self, agent_id: str) -> List[str]: - """List all tool IDs associated with a specific agent.""" - with self.session_maker() as session: - tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id) - return [record.tool_id for record in tools_agents_record] - - def list_tool_names_for_agent(self, agent_id: str) -> List[str]: - """List all tool names associated with a specific agent.""" - with self.session_maker() as session: - tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id) - return [record.tool_name for record in tools_agents_record] - - def list_agent_ids_with_tool(self, tool_id: str) -> List[str]: - """List all agents associated with a specific tool.""" - with self.session_maker() as session: - tools_agents_record = ToolsAgentsModel.list(db_session=session, tool_id=tool_id) - return [record.agent_id for record in tools_agents_record] - - def get_tool_id_for_name(self, agent_id: str, tool_name: str) -> str: - """Get the tool ID for a specific tool name for an agent.""" - with self.session_maker() as session: - try: - tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name) - return tools_agents_record.tool_id - except NoResultFound: - raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.") - - def remove_all_agent_tools(self, agent_id: str) -> None: - """Remove all tools associated with an agent.""" - with self.session_maker() as session: - tools_agents_records = ToolsAgentsModel.list(db_session=session, agent_id=agent_id) - for record in tools_agents_records: - record.hard_delete(session) \ No newline at end of file diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index cc99ad8c..5dca0fff 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -73,12 +73,6 @@ class UserManager: user = UserModel.read(db_session=session, identifier=user_id) user.hard_delete(session) - # TODO: Integrate this via the ORM models for the Agent, Source, and AgentSourceMapping - # Cascade delete for related models: Agent, Source, AgentSourceMapping - # session.query(AgentModel).filter(AgentModel.user_id == user_id).delete() - # session.query(SourceModel).filter(SourceModel.user_id == user_id).delete() - # session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete() - session.commit() @enforce_types @@ -93,6 +87,17 @@ class UserManager: """Fetch the default user.""" return self.get_user_by_id(self.DEFAULT_USER_ID) + @enforce_types + def get_user_or_default(self, user_id: Optional[str] = None): + """Fetch the user or default user.""" + if not user_id: + return self.get_default_user() + + try: + return self.get_user_by_id(user_id=user_id) + except NoResultFound: + return self.get_default_user() + @enforce_types def list_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> Tuple[Optional[str], List[PydanticUser]]: """List users with pagination using cursor (id) and limit.""" diff --git a/letta/utils.py b/letta/utils.py index ad666885..7184a0e3 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -548,13 +548,13 @@ def enforce_types(func): for arg_name, arg_value in args_with_hints.items(): hint = hints.get(arg_name) if hint and not matches_type(arg_value, hint): - raise ValueError(f"Argument {arg_name} does not match type {hint}") + raise ValueError(f"Argument {arg_name} does not match type {hint}; is {arg_value}") # Check types of keyword arguments for arg_name, arg_value in kwargs.items(): hint = hints.get(arg_name) if hint and not matches_type(arg_value, hint): - raise ValueError(f"Argument {arg_name} does not match type {hint}") + raise ValueError(f"Argument {arg_name} does not match type {hint}; is {arg_value}") return func(*args, **kwargs) diff --git a/locust_test.py b/locust_test.py index 1e74d405..570e6eef 100644 --- a/locust_test.py +++ b/locust_test.py @@ -4,7 +4,7 @@ import string from locust import HttpUser, between, task from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA -from letta.schemas.agent import CreateAgent, PersistedAgentState +from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.letta_request import LettaRequest from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ChatMemory @@ -49,7 +49,7 @@ class LettaUser(HttpUser): response.failure(f"Failed to create agent: {response.text}") response_json = response.json() - agent_state = PersistedAgentState(**response_json) + agent_state = AgentState(**response_json) self.agent_id = agent_state.id print("Created agent", self.agent_id, agent_state.name) diff --git a/scripts/migrate_0.3.18.py b/scripts/migrate_0.3.18.py deleted file mode 100644 index 3464ea12..00000000 --- a/scripts/migrate_0.3.18.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import uuid - -from sqlalchemy import MetaData, Table, create_engine - -from letta import create_client -from letta.config import LettaConfig -from letta.data_types import AgentState, EmbeddingConfig, LLMConfig -from letta.metadata import MetadataStore -from letta.presets.presets import add_default_tools -from letta.prompts import gpt_system - -# Replace this with your actual database connection URL -config = LettaConfig.load() -if config.recall_storage_type == "sqlite": - DATABASE_URL = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db") -else: - DATABASE_URL = config.recall_storage_uri -print(DATABASE_URL) -engine = create_engine(DATABASE_URL) -metadata = MetaData() - -# defaults -system_prompt = gpt_system.get_system_text("memgpt_chat") - -# Reflect the existing table -table = Table("agents", metadata, autoload_with=engine) - - -# get all agent rows -agent_states = [] -with engine.connect() as conn: - agents = conn.execute(table.select()).fetchall() - for agent in agents: - id = uuid.UUID(agent[0]) - user_id = uuid.UUID(agent[1]) - name = agent[2] - print(f"Migrating agent {name}") - persona = agent[3] - human = agent[4] - system = agent[5] - preset = agent[6] - created_at = agent[7] - llm_config = LLMConfig(**agent[8]) - embedding_config = EmbeddingConfig(**agent[9]) - state = agent[10] - tools = agent[11] - - state["memory"] = {"human": {"value": human, "limit": 2000}, "persona": {"value": persona, "limit": 2000}} - - agent_state = AgentState( - id=id, - user_id=user_id, - name=name, - system=system, - created_at=created_at, - llm_config=llm_config, - embedding_config=embedding_config, - state=state, - tools=tools, - _metadata={"human": "migrated", "persona": "migrated"}, - ) - - agent_states.append(agent_state) - -# remove agents table -agents_model = Table("agents", metadata, autoload_with=engine) -agents_model.drop(engine) - -# remove tool table -tool_model = Table("toolmodel", metadata, autoload_with=engine) -tool_model.drop(engine) - -# re-create tables and add default tools -ms = MetadataStore(config) -add_default_tools(None, ms) -print("Tools", [tool.name for tool in ms.list_tools()]) - - -for agent in agent_states: - ms.create_agent(agent) - print(f"Agent {agent.name} migrated successfully!") - -# add another agent to create core memory tool -client = create_client() -dummy_agent = client.create_agent(name="dummy_agent") -tools = client.list_tools() -assert "core_memory_append" in [tool.name for tool in tools] - -print("Migration completed successfully!") diff --git a/tests/conftest.py b/tests/conftest.py index 899a74af..17ae8ef9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,6 @@ import logging import pytest -from letta.settings import tool_settings - def pytest_configure(config): logging.basicConfig(level=logging.DEBUG) @@ -11,6 +9,8 @@ def pytest_configure(config): @pytest.fixture def mock_e2b_api_key_none(): + from letta.settings import tool_settings + # Store the original value of e2b_api_key original_api_key = tool_settings.e2b_api_key diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 47424572..fbe7ffb6 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -61,7 +61,7 @@ def setup_agent( filename: str, memory_human_str: str = get_human_text(DEFAULT_HUMAN), memory_persona_str: str = get_persona_text(DEFAULT_PERSONA), - tools: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, tool_rules: Optional[List[BaseToolRule]] = None, agent_uuid: str = agent_uuid, ) -> AgentState: @@ -77,7 +77,7 @@ def setup_agent( memory = ChatMemory(human=memory_human_str, persona=memory_persona_str) agent_state = client.create_agent( - name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tools=tools, tool_rules=tool_rules + name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules ) return agent_state @@ -103,7 +103,6 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet cleanup(client=client, agent_uuid=agent_uuid) agent_state = setup_agent(client, filename) - tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tool_names] full_agent_state = client.get_agent(agent_state.id) agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) @@ -171,19 +170,18 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse: client = create_client() cleanup(client=client, agent_uuid=agent_uuid) tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - tool_name = tool.name # Set up persona for tool usage persona = f""" My name is Letta. - I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question. + I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool.name} which will search `example.com` and answer the relevant question. Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. """ - agent_state = setup_agent(client, filename, memory_persona_str=persona, tools=[tool_name]) + agent_state = setup_agent(client, filename, memory_persona_str=persona, tool_ids=[tool.id]) response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?") @@ -191,7 +189,7 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse: assert_sanity_checks(response) # Make sure the tool was called - assert_invoked_function_call(response.messages, tool_name) + assert_invoked_function_call(response.messages, tool.name) # Make sure some inner monologue is present assert_inner_monologue_is_present_and_valid(response.messages) @@ -334,7 +332,7 @@ def check_agent_summarize_memory_simple(filename: str) -> LettaResponse: client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?") # Summarize - agent = client.server.load_agent(agent_id=agent_state.id) + agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user) agent.summarize_messages_inplace() print(f"Summarization succeeded: messages[1] = \n\n{json_dumps(agent.messages[1])}\n") diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 4269fdd8..803fc98c 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -3,6 +3,7 @@ from typing import Union from letta import LocalClient, RESTClient from letta.functions.functions import parse_source_code from letta.functions.schema_generator import generate_schema +from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.tool import Tool @@ -24,3 +25,57 @@ def create_tool_from_func(func: callable): source_code=parse_source_code(func), json_schema=generate_schema(func, None), ) + + +def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent]): + # Assert scalar fields + assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}" + assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}" + assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}" + + # Assert agent type + if hasattr(request, "agent_type"): + assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}" + + # Assert LLM configuration + assert agent.llm_config == request.llm_config, f"LLM config mismatch: {agent.llm_config} != {request.llm_config}" + + # Assert embedding configuration + assert ( + agent.embedding_config == request.embedding_config + ), f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}" + + # Assert memory blocks + if hasattr(request, "memory_blocks"): + assert len(agent.memory.blocks) == len(request.memory_blocks) + len( + request.block_ids + ), f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(request.block_ids)}" + memory_block_values = {block.value for block in agent.memory.blocks} + expected_block_values = {block.value for block in request.memory_blocks} + assert expected_block_values.issubset( + memory_block_values + ), f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}" + + # Assert tools + assert len(agent.tools) == len(request.tool_ids), f"Tools count mismatch: {len(agent.tools)} != {len(request.tool_ids)}" + assert {tool.id for tool in agent.tools} == set( + request.tool_ids + ), f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}" + + # Assert sources + assert len(agent.sources) == len(request.source_ids), f"Sources count mismatch: {len(agent.sources)} != {len(request.source_ids)}" + assert {source.id for source in agent.sources} == set( + request.source_ids + ), f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}" + + # Assert tags + assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}" + + # Assert tool rules + if request.tool_rules: + assert len(agent.tool_rules) == len( + request.tool_rules + ), f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}" + assert all( + any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules + ), f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}" diff --git a/tests/test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py similarity index 98% rename from tests/test_agent_tool_graph.py rename to tests/integration_test_agent_tool_graph.py index 7774977c..ff8700c1 100644 --- a/tests/test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -99,7 +99,7 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none): ] # Make agent state - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules) + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?") # Make checks diff --git a/tests/test_o1_agent.py b/tests/integration_test_o1_agent.py similarity index 95% rename from tests/test_o1_agent.py rename to tests/integration_test_o1_agent.py index 86212ffc..6c8c62a1 100644 --- a/tests/test_o1_agent.py +++ b/tests/integration_test_o1_agent.py @@ -17,7 +17,7 @@ def test_o1_agent(): agent_state = client.create_agent( agent_type=AgentType.o1_agent, - tools=[thinking_tool.name, final_tool.name], + tool_ids=[thinking_tool.id, final_tool.id], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), memory=ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text("o1_persona")), diff --git a/tests/test_offline_memory_agent.py b/tests/integration_test_offline_memory_agent.py similarity index 92% rename from tests/test_offline_memory_agent.py rename to tests/integration_test_offline_memory_agent.py index ff3aca5c..8a4fb81c 100644 --- a/tests/test_offline_memory_agent.py +++ b/tests/integration_test_offline_memory_agent.py @@ -32,8 +32,10 @@ def clear_agents(client): for agent in client.list_agents(): client.delete_agent(agent.id) + def test_ripple_edit(client, mock_e2b_api_key_none): trigger_rethink_memory_tool = client.create_or_update_tool(trigger_rethink_memory) + send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user) conversation_human_block = Block(name="human", label="human", value=get_human_text(DEFAULT_HUMAN), limit=2000) conversation_persona_block = Block(name="persona", label="persona", value=get_persona_text(DEFAULT_PERSONA), limit=2000) @@ -64,7 +66,7 @@ def test_ripple_edit(client, mock_e2b_api_key_none): system=gpt_system.get_system_text("memgpt_convo_only"), llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - tools=["send_message", trigger_rethink_memory_tool.name], + tool_ids=[send_message.id, trigger_rethink_memory_tool.id], memory=conversation_memory, include_base_tools=False, ) @@ -81,7 +83,7 @@ def test_ripple_edit(client, mock_e2b_api_key_none): memory=offline_memory, llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - tools=[rethink_memory_tool.name, finish_rethinking_memory_tool.name], + tool_ids=[rethink_memory_tool.id, finish_rethinking_memory_tool.id], tool_rules=[TerminalToolRule(tool_name=finish_rethinking_memory_tool.name)], include_base_tools=False, ) @@ -111,16 +113,16 @@ def test_chat_only_agent(client, mock_e2b_api_key_none): ) conversation_memory = BasicBlockMemory(blocks=[conversation_persona_block, conversation_human_block]) - client = create_client() + send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user) chat_only_agent = client.create_agent( name="conversation_agent", agent_type=AgentType.chat_only_agent, llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - tools=["send_message"], + tool_ids=[send_message.id], memory=conversation_memory, include_base_tools=False, - metadata={"offline_memory_tools": [rethink_memory.name, finish_rethinking_memory.name]}, + metadata={"offline_memory_tools": [rethink_memory.id, finish_rethinking_memory.id]}, ) assert chat_only_agent is not None assert set(chat_only_agent.memory.list_block_labels()) == {"chat_agent_persona", "chat_agent_human"} @@ -135,6 +137,7 @@ def test_chat_only_agent(client, mock_e2b_api_key_none): # Clean up agent client.delete_agent(chat_only_agent.id) + def test_initial_message_sequence(client, mock_e2b_api_key_none): """ Test that when we set the initial sequence to an empty list, @@ -150,8 +153,6 @@ def test_initial_message_sequence(client, mock_e2b_api_key_none): initial_message_sequence=[], ) assert offline_memory_agent is not None - assert len(offline_memory_agent.message_ids) == 1 # There should just the system message + assert len(offline_memory_agent.message_ids) == 1 # There should just the system message client.delete_agent(offline_memory_agent.id) - - diff --git a/tests/test_autogen_integration.py b/tests/test_autogen_integration.py deleted file mode 100644 index 251f7590..00000000 --- a/tests/test_autogen_integration.py +++ /dev/null @@ -1,41 +0,0 @@ -# TODO: add back - -# import os -# import subprocess -# -# import pytest -# -# -# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key") -# def test_agent_groupchat(): -# -# # Define the path to the script you want to test -# script_path = "letta/autogen/examples/agent_groupchat.py" -# -# # Dynamically get the project's root directory (assuming this script is run from the root) -# # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) -# # print(project_root) -# # project_root = os.path.join(project_root, "Letta") -# # print(project_root) -# # sys.exit(1) -# -# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -# project_root = os.path.join(project_root, "letta") -# print(f"Adding the following to PATH: {project_root}") -# -# # Prepare the environment, adding the project root to PYTHONPATH -# env = os.environ.copy() -# env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}" -# -# # Run the script using subprocess.run -# # Capture the output (stdout) and the exit code -# # result = subprocess.run(["python", script_path], capture_output=True, text=True) -# result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True) -# -# # Check the exit code (0 indicates success) -# assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}" -# -# # Optionally, check the output for expected content -# # For example, if you expect a specific line in the output, uncomment and adapt the following line: -# # assert "expected output" in result.stdout, "Expected output not found in script's output" -# diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 199800eb..81446719 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -23,7 +23,7 @@ def agent_obj(): agent_state = client.create_agent() global agent_obj - agent_obj = client.server.load_agent(agent_id=agent_state.id) + agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) yield agent_obj client.delete_agent(agent_obj.agent_state.id) @@ -35,49 +35,50 @@ def query_in_search_results(search_results, query): return True return False + def test_archival(agent_obj): """Test archival memory functions comprehensively.""" # Test 1: Basic insertion and retrieval base_functions.archival_memory_insert(agent_obj, "The cat sleeps on the mat") base_functions.archival_memory_insert(agent_obj, "The dog plays in the park") base_functions.archival_memory_insert(agent_obj, "Python is a programming language") - + # Test exact text search results, _ = base_functions.archival_memory_search(agent_obj, "cat") assert query_in_search_results(results, "cat") - + # Test semantic search (should return animal-related content) results, _ = base_functions.archival_memory_search(agent_obj, "animal pets") assert query_in_search_results(results, "cat") or query_in_search_results(results, "dog") - + # Test unrelated search (should not return animal content) results, _ = base_functions.archival_memory_search(agent_obj, "programming computers") assert query_in_search_results(results, "python") - + # Test 2: Test pagination # Insert more items to test pagination for i in range(10): base_functions.archival_memory_insert(agent_obj, f"Test passage number {i}") - + # Get first page page0_results, next_page = base_functions.archival_memory_search(agent_obj, "Test passage", page=0) # Get second page page1_results, _ = base_functions.archival_memory_search(agent_obj, "Test passage", page=1, start=next_page) - + assert page0_results != page1_results assert query_in_search_results(page0_results, "Test passage") assert query_in_search_results(page1_results, "Test passage") - + # Test 3: Test complex text patterns base_functions.archival_memory_insert(agent_obj, "Important meeting on 2024-01-15 with John") base_functions.archival_memory_insert(agent_obj, "Follow-up meeting scheduled for next week") base_functions.archival_memory_insert(agent_obj, "Project deadline is approaching") - + # Search for meeting-related content results, _ = base_functions.archival_memory_search(agent_obj, "meeting schedule") assert query_in_search_results(results, "meeting") assert query_in_search_results(results, "2024-01-15") or query_in_search_results(results, "next week") - + # Test 4: Test error handling # Test invalid page number try: @@ -85,7 +86,7 @@ def test_archival(agent_obj): assert False, "Should have raised ValueError" except ValueError: pass - + def test_recall(agent_obj): base_functions.conversation_search(agent_obj, "banana") diff --git a/tests/test_client.py b/tests/test_client.py index 2c92ef95..526559b7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,4 @@ import asyncio -import json import os import threading import time @@ -42,8 +41,8 @@ def run_server(): @pytest.fixture( - # params=[{"server": False}, {"server": True}], # whether to use REST API server - params=[{"server": False}], # whether to use REST API server + params=[{"server": False}, {"server": True}], # whether to use REST API server + # params=[{"server": True}], # whether to use REST API server scope="module", ) def client(request): @@ -69,6 +68,7 @@ def client(request): @pytest.fixture(scope="module") def agent(client: Union[LocalClient, RESTClient]): agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}") + yield agent_state # delete agent @@ -86,6 +86,47 @@ def clear_tables(): session.commit() +def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]): + # _reset_config() + + # create a block + block = client.create_block(label="human", value="username: sarah") + + # create agents with shared block + from letta.schemas.block import Block + from letta.schemas.memory import BasicBlockMemory + + # persona1_block = client.create_block(label="persona", value="you are agent 1") + # persona2_block = client.create_block(label="persona", value="you are agent 2") + # create agents + agent_state1 = client.create_agent( + name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id] + ) + agent_state2 = client.create_agent( + name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id] + ) + + ## attach shared block to both agents + # client.link_agent_memory_block(agent_state1.id, block.id) + # client.link_agent_memory_block(agent_state2.id, block.id) + + # update memory + client.user_message(agent_id=agent_state1.id, message="my name is actually charles") + + # check agent 2 memory + assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}" + + client.user_message(agent_id=agent_state2.id, message="whats my name?") + assert ( + "charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower() + ), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}" + # assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}" + + # cleanup + client.delete_agent(agent_state1.id) + client.delete_agent(agent_state2.id) + + def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]): """ Test sandbox config and environment variable functions for both LocalClient and RESTClient. @@ -137,15 +178,15 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient] client.delete_sandbox_config(sandbox_config_id=sandbox_config.id) -def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]): """ Comprehensive happy path test for adding, retrieving, and managing tags on an agent. """ tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"] - # Step 0: create an agent with tags - tagged_agent = client.create_agent(tags=tags_to_add) - assert set(tagged_agent.tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {tagged_agent.tags}" + # Step 0: create an agent with no tags + agent = client.create_agent() + assert len(agent.tags) == 0 # Step 1: Add multiple tags to the agent client.update_agent(agent_id=agent.id, tags=tags_to_add) @@ -175,6 +216,9 @@ def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], a final_tags = client.get_agent(agent_id=agent.id).tags assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}" + # Remove agent + client.delete_agent(agent.id) + def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState): """Test that we can update the label of a block in an agent's memory""" @@ -255,35 +299,33 @@ def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], a # client.delete_agent(new_agent.id) -def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]): """Test that we can update the limit of a block in an agent's memory""" - agent = client.create_agent(name=create_random_username()) + agent = client.create_agent() - try: - current_labels = agent.memory.list_block_labels() - example_label = current_labels[0] - example_new_limit = 1 - current_block = agent.memory.get_block(label=example_label) - current_block_length = len(current_block.value) + current_labels = agent.memory.list_block_labels() + example_label = current_labels[0] + example_new_limit = 1 + current_block = agent.memory.get_block(label=example_label) + current_block_length = len(current_block.value) - assert example_new_limit != agent.memory.get_block(label=example_label).limit - assert example_new_limit < current_block_length + assert example_new_limit != agent.memory.get_block(label=example_label).limit + assert example_new_limit < current_block_length - # We expect this to throw a value error - with pytest.raises(ValueError): - client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) - - # Now try the same thing with a higher limit - example_new_limit = current_block_length + 10000 - assert example_new_limit > current_block_length + # We expect this to throw a value error + with pytest.raises(ValueError): client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) - updated_agent = client.get_agent(agent_id=agent.id) - assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit + # Now try the same thing with a higher limit + example_new_limit = current_block_length + 10000 + assert example_new_limit > current_block_length + client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) - finally: - client.delete_agent(agent.id) + updated_agent = client.get_agent(agent_id=agent.id) + assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit + + client.delete_agent(agent.id) def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState): @@ -316,7 +358,7 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]): padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50 tool = client.create_or_update_tool(func=big_return, return_char_limit=1000) - agent = client.create_agent(name="agent1", tools=[tool.name]) + agent = client.create_agent(tool_ids=[tool.id]) # get function response response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user") print(response.messages) @@ -330,10 +372,14 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]): assert response_message, "FunctionReturn message not found in response" res = response_message.function_return assert "function output was truncated " in res - res_json = json.loads(res) - assert ( - len(res_json["message"]) <= 1000 + padding - ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}" + + # TODO: Re-enable later + # res_json = json.loads(res) + # assert ( + # len(res_json["message"]) <= 1000 + padding + # ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}" + + client.delete_agent(agent_id=agent.id) @pytest.mark.asyncio diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 2ee92293..3839611b 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -583,43 +583,6 @@ def test_list_llm_models(client: RESTClient): assert has_model_endpoint_type(models, "anthropic") -def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - # create a block - block = client.create_block(label="human", value="username: sarah") - - # create agents with shared block - from letta.schemas.block import Block - from letta.schemas.memory import BasicBlockMemory - - # persona1_block = client.create_block(label="persona", value="you are agent 1") - # persona2_block = client.create_block(label="persona", value="you are agent 2") - # create agnets - agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1"), block])) - agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2"), block])) - - ## attach shared block to both agents - # client.link_agent_memory_block(agent_state1.id, block.id) - # client.link_agent_memory_block(agent_state2.id, block.id) - - # update memory - response = client.user_message(agent_id=agent_state1.id, message="my name is actually charles") - - # check agent 2 memory - assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}" - - response = client.user_message(agent_id=agent_state2.id, message="whats my name?") - assert ( - "charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower() - ), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}" - # assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}" - - # cleanup - client.delete_agent(agent_state1.id) - client.delete_agent(agent_state2.id) - - @pytest.fixture def cleanup_agents(client): created_agents = [] diff --git a/tests/test_concurrent_connections.py b/tests/test_concurrent_connections.py deleted file mode 100644 index c3b4f8df..00000000 --- a/tests/test_concurrent_connections.py +++ /dev/null @@ -1,142 +0,0 @@ -# TODO: add back when messaging works - -# import os -# import threading -# import time -# import uuid -# -# import pytest -# from dotenv import load_dotenv -# -# from letta import Admin, create_client -# from letta.config import LettaConfig -# from letta.credentials import LettaCredentials -# from letta.settings import settings -# from tests.utils import create_config -# -# test_agent_name = f"test_client_{str(uuid.uuid4())}" -## test_preset_name = "test_preset" -# test_agent_state = None -# client = None -# -# test_agent_state_post_message = None -# test_user_id = uuid.uuid4() -# -# -## admin credentials -# test_server_token = "test_server_token" -# -# -# def _reset_config(): -# -# # Use os.getenv with a fallback to os.environ.get -# db_url = settings.letta_pg_uri -# -# if os.getenv("OPENAI_API_KEY"): -# create_config("openai") -# credentials = LettaCredentials( -# openai_key=os.getenv("OPENAI_API_KEY"), -# ) -# else: # hosted -# create_config("letta_hosted") -# credentials = LettaCredentials() -# -# config = LettaConfig.load() -# -# # set to use postgres -# config.archival_storage_uri = db_url -# config.recall_storage_uri = db_url -# config.metadata_storage_uri = db_url -# config.archival_storage_type = "postgres" -# config.recall_storage_type = "postgres" -# config.metadata_storage_type = "postgres" -# -# config.save() -# credentials.save() -# print("_reset_config :: ", config.config_path) -# -# -# def run_server(): -# -# load_dotenv() -# -# _reset_config() -# -# from letta.server.rest_api.server import start_server -# -# print("Starting server...") -# start_server(debug=True) -# -# -## Fixture to create clients with different configurations -# @pytest.fixture( -# params=[ # whether to use REST API server -# {"server": True}, -# # {"server": False} # TODO: add when implemented -# ], -# scope="module", -# ) -# def admin_client(request): -# if request.param["server"]: -# # get URL from enviornment -# server_url = os.getenv("MEMGPT_SERVER_URL") -# if server_url is None: -# # run server in thread -# # NOTE: must set MEMGPT_SERVER_PASS enviornment variable -# server_url = "http://localhost:8283" -# print("Starting server thread") -# thread = threading.Thread(target=run_server, daemon=True) -# thread.start() -# time.sleep(5) -# print("Running client tests with server:", server_url) -# # create user via admin client -# admin = Admin(server_url, test_server_token) -# response = admin.create_user(test_user_id) # Adjust as per your client's method -# -# yield admin -# -# -# def test_concurrent_messages(admin_client): -# # test concurrent messages -# -# # create three -# -# results = [] -# -# def _send_message(): -# try: -# print("START SEND MESSAGE") -# response = admin_client.create_user() -# token = response.api_key -# client = create_client(base_url=admin_client.base_url, token=token) -# agent = client.create_agent() -# -# print("Agent created", agent.id) -# -# st = time.time() -# message = "Hello, how are you?" -# response = client.send_message(agent_id=agent.id, message=message, role="user") -# et = time.time() -# print(f"Message sent from {st} to {et}") -# print(response.messages) -# results.append((st, et)) -# except Exception as e: -# print("ERROR", e) -# -# threads = [] -# print("Starting threads...") -# for i in range(5): -# thread = threading.Thread(target=_send_message) -# threads.append(thread) -# thread.start() -# print("CREATED THREAD") -# -# print("waiting for threads to finish...") -# for thread in threads: -# print(thread.join()) -# -# # make sure runtime are overlapping -# assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or ( -# results[1][0] < results[0][0] and results[1][1] > results[0][0] -# ), f"Threads should have overlapping runtimes {results}" -# diff --git a/tests/test_different_embedding_size.py b/tests/test_different_embedding_size.py deleted file mode 100644 index 58748339..00000000 --- a/tests/test_different_embedding_size.py +++ /dev/null @@ -1,121 +0,0 @@ -# TODO: add back once tests are cleaned up - -# import os -# import uuid -# -# from letta import create_client -# from letta.agent_store.storage import StorageConnector, TableType -# from letta.schemas.passage import Passage -# from letta.embeddings import embedding_model -# from tests import TEST_MEMGPT_CONFIG -# -# from .utils import create_config, wipe_config -# -# test_agent_name = f"test_client_{str(uuid.uuid4())}" -# test_agent_state = None -# client = None -# -# test_agent_state_post_message = None -# test_user_id = uuid.uuid4() -# -# -# def generate_passages(user, agent): -# # Note: the database will filter out rows that do not correspond to agent1 and test_user by default. -# texts = [ -# "This is a test passage", -# "This is another test passage", -# "Cinderella wept", -# ] -# embed_model = embedding_model(agent.embedding_config) -# orig_embeddings = [] -# passages = [] -# for text in texts: -# embedding = embed_model.get_text_embedding(text) -# orig_embeddings.append(list(embedding)) -# passages.append( -# Passage( -# user_id=user.id, -# agent_id=agent.id, -# text=text, -# embedding=embedding, -# embedding_dim=agent.embedding_config.embedding_dim, -# embedding_model=agent.embedding_config.embedding_model, -# ) -# ) -# return passages, orig_embeddings -# -# -# def test_create_user(): -# if not os.getenv("OPENAI_API_KEY"): -# print("Skipping test, missing OPENAI_API_KEY") -# return -# -# wipe_config() -# -# # create client -# create_config("openai") -# client = create_client() -# -# # openai: create agent -# openai_agent = client.create_agent( -# name="openai_agent", -# ) -# assert ( -# openai_agent.embedding_config.embedding_endpoint_type == "openai" -# ), f"openai_agent.embedding_config.embedding_endpoint_type={openai_agent.embedding_config.embedding_endpoint_type}" -# -# # openai: add passages -# passages, openai_embeddings = generate_passages(client.user, openai_agent) -# openai_agent_run = client.server.load_agent(user_id=client.user.id, agent_id=openai_agent.id) -# openai_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) -# -# # create client -# create_config("letta_hosted") -# client = create_client() -# -# # hosted: create agent -# hosted_agent = client.create_agent( -# name="hosted_agent", -# ) -# # check to make sure endpoint overriden -# assert ( -# hosted_agent.embedding_config.embedding_endpoint_type == "hugging-face" -# ), f"hosted_agent.embedding_config.embedding_endpoint_type={hosted_agent.embedding_config.embedding_endpoint_type}" -# -# # hosted: add passages -# passages, hosted_embeddings = generate_passages(client.user, hosted_agent) -# hosted_agent_run = client.server.load_agent(user_id=client.user.id, agent_id=hosted_agent.id) -# hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) -# -# # test passage dimentionality -# storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id) -# storage.filters = {} # clear filters to be able to get all passages -# passages = storage.get_all() -# for passage in passages: -# if passage.agent_id == hosted_agent.id: -# assert ( -# passage.embedding_dim == hosted_agent.embedding_config.embedding_dim -# ), f"passage.embedding_dim={passage.embedding_dim} != hosted_agent.embedding_config.embedding_dim={hosted_agent.embedding_config.embedding_dim}" -# -# # ensure was in original embeddings -# embedding = passage.embedding[: passage.embedding_dim] -# assert embedding in hosted_embeddings, f"embedding={embedding} not in hosted_embeddings={hosted_embeddings}" -# -# # make sure all zeros -# assert not any( -# passage.embedding[passage.embedding_dim :] -# ), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}" -# elif passage.agent_id == openai_agent.id: -# assert ( -# passage.embedding_dim == openai_agent.embedding_config.embedding_dim -# ), f"passage.embedding_dim={passage.embedding_dim} != openai_agent.embedding_config.embedding_dim={openai_agent.embedding_config.embedding_dim}" -# -# # ensure was in original embeddings -# embedding = passage.embedding[: passage.embedding_dim] -# assert embedding in openai_embeddings, f"embedding={embedding} not in openai_embeddings={openai_embeddings}" -# -# # make sure all zeros -# assert not any( -# passage.embedding[passage.embedding_dim :] -# ), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}" -# diff --git a/tests/test_function_parser.py b/tests/test_function_parser.py deleted file mode 100644 index 64b22336..00000000 --- a/tests/test_function_parser.py +++ /dev/null @@ -1,48 +0,0 @@ -import letta.system as system -from letta.local_llm.function_parser import patch_function -from letta.utils import json_dumps - -EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = { - "message_history": [ - {"role": "user", "content": system.package_user_message("hello")}, - ], - # "new_message": { - # "role": "function", - # "name": "send_message", - # "content": system.package_function_response(was_success=True, response_string="None"), - # }, - "new_message": { - "role": "assistant", - "content": "I'll send a message.", - "function_call": { - "name": "send_message", - "arguments": "null", - }, - }, -} - -EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = { - "message_history": [ - {"role": "user", "content": system.package_user_message("hello")}, - ], - "new_message": { - "role": "assistant", - "content": "I'll append to memory.", - "function_call": { - "name": "core_memory_append", - "arguments": json_dumps({"content": "new_stuff"}), - }, - }, -} - - -def test_function_parsers(): - """Try various broken JSON and check that the parsers can fix it""" - - og_message = EXAMPLE_FUNCTION_CALL_SEND_MESSAGE["new_message"] - corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_SEND_MESSAGE) - assert corrected_message == og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}" - - og_message = EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING["new_message"].copy() - corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING) - assert corrected_message != og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}" diff --git a/tests/test_json_parsers.py b/tests/test_json_parsers.py deleted file mode 100644 index 64c3b3f7..00000000 --- a/tests/test_json_parsers.py +++ /dev/null @@ -1,99 +0,0 @@ -import letta.local_llm.json_parser as json_parser -from letta.utils import json_loads - -EXAMPLE_ESCAPED_UNDERSCORES = """{ - "function":"send\_message", - "params": { - "inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.", - "message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?" -""" - - -EXAMPLE_MISSING_CLOSING_BRACE = """{ - "function": "send_message", - "params": { - "inner_thoughts": "Oops, I got their name wrong! I should apologize and correct myself.", - "message": "Sorry about that! I assumed you were Chad. Welcome, Brad! " - } -""" - -EXAMPLE_BAD_TOKEN_END = """{ - "function": "send_message", - "params": { - "inner_thoughts": "Oops, I got their name wrong! I should apologize and correct myself.", - "message": "Sorry about that! I assumed you were Chad. Welcome, Brad! " - } -}<|>""" - -EXAMPLE_DOUBLE_JSON = """{ - "function": "core_memory_append", - "params": { - "name": "human", - "content": "Brad, 42 years old, from Germany." - } -} -{ - "function": "send_message", - "params": { - "message": "Got it! Your age and nationality are now saved in my memory." - } -} -""" - -EXAMPLE_HARD_LINE_FEEDS = """{ - "function": "send_message", - "params": { - "message": "Let's create a list: -- First, we can do X -- Then, we can do Y! -- Lastly, we can do Z :)" - } -} -""" - -# Situation where beginning of send_message call is fine (and thus can be extracted) -# but has a long training garbage string that comes after -EXAMPLE_SEND_MESSAGE_PREFIX_OK_REST_BAD = """{ - "function": "send_message", - "params": { - "inner_thoughts": "User request for debug assistance", - "message": "Of course, Chad. Please check the system log file for 'assistant.json' and send me the JSON output you're getting. Armed with that data, I'll assist you in debugging the issue.", -GARBAGEGARBAGEGARBAGEGARBAGE -GARBAGEGARBAGEGARBAGEGARBAGE -GARBAGEGARBAGEGARBAGEGARBAGE -""" - -EXAMPLE_ARCHIVAL_SEARCH = """ - -{ - "function": "archival_memory_search", - "params": { - "inner_thoughts": "Looking for WaitingForAction.", - "query": "WaitingForAction", -""" - - -def test_json_parsers(): - """Try various broken JSON and check that the parsers can fix it""" - - test_strings = [ - EXAMPLE_ESCAPED_UNDERSCORES, - EXAMPLE_MISSING_CLOSING_BRACE, - EXAMPLE_BAD_TOKEN_END, - EXAMPLE_DOUBLE_JSON, - EXAMPLE_HARD_LINE_FEEDS, - EXAMPLE_SEND_MESSAGE_PREFIX_OK_REST_BAD, - EXAMPLE_ARCHIVAL_SEARCH, - ] - - for string in test_strings: - try: - json_loads(string) - assert False, f"Test JSON string should have failed basic JSON parsing:\n{string}" - except: - print("String failed (expectedly)") - try: - json_parser.clean_json(string) - except: - f"Failed to repair test JSON string:\n{string}" - raise diff --git a/tests/test_local_client.py b/tests/test_local_client.py index 3aa947ba..ea5d04e0 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -4,7 +4,7 @@ import pytest from letta import create_client from letta.client.client import LocalClient -from letta.schemas.agent import PersistedAgentState +from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory @@ -13,6 +13,7 @@ from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory @pytest.fixture(scope="module") def client(): client = create_client() + # client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) @@ -29,7 +30,6 @@ def agent(client): yield agent_state client.delete_agent(agent_state.id) - assert client.get_agent(agent_state.id) is None, f"Failed to properly delete agent {agent_state.id}" def test_agent(client: LocalClient): @@ -80,16 +80,15 @@ def test_agent(client: LocalClient): assert isinstance(agent_state.memory, Memory) # update agent: tools tool_to_delete = "send_message" - assert tool_to_delete in agent_state.tool_names - new_agent_tools = [t_name for t_name in agent_state.tool_names if t_name != tool_to_delete] - client.update_agent(agent_state_test.id, tools=new_agent_tools) - assert client.get_agent(agent_state_test.id).tool_names == new_agent_tools + assert tool_to_delete in [t.name for t in agent_state.tools] + new_agent_tool_ids = [t.id for t in agent_state.tools if t.name != tool_to_delete] + client.update_agent(agent_state_test.id, tool_ids=new_agent_tool_ids) + assert sorted([t.id for t in client.get_agent(agent_state_test.id).tools]) == sorted(new_agent_tool_ids) assert isinstance(agent_state.memory, Memory) # update agent: memory new_human = "My name is Mr Test, 100 percent human." new_persona = "I am an all-knowing AI." - new_memory = ChatMemory(human=new_human, persona=new_persona) assert agent_state.memory.get_block("human").value != new_human assert agent_state.memory.get_block("persona").value != new_persona @@ -216,7 +215,7 @@ def test_agent_with_shared_blocks(client: LocalClient): client.delete_agent(second_agent_state_test.id) -def test_memory(client: LocalClient, agent: PersistedAgentState): +def test_memory(client: LocalClient, agent: AgentState): # get agent memory original_memory = client.get_in_context_memory(agent.id) assert original_memory is not None @@ -229,7 +228,7 @@ def test_memory(client: LocalClient, agent: PersistedAgentState): assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated -def test_archival_memory(client: LocalClient, agent: PersistedAgentState): +def test_archival_memory(client: LocalClient, agent: AgentState): """Test functions for interacting with archival memory store""" # add archival memory @@ -244,7 +243,7 @@ def test_archival_memory(client: LocalClient, agent: PersistedAgentState): client.delete_archival_memory(agent.id, passage.id) -def test_recall_memory(client: LocalClient, agent: PersistedAgentState): +def test_recall_memory(client: LocalClient, agent: AgentState): """Test functions for interacting with recall memory store""" # send message to the agent diff --git a/tests/test_managers.py b/tests/test_managers.py index 745b17d7..96b5faa4 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4,12 +4,13 @@ from datetime import datetime, timedelta import pytest from sqlalchemy import delete +from sqlalchemy.exc import IntegrityError +from letta.config import LettaConfig from letta.embeddings import embedding_model -import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.metadata import AgentModel from letta.orm import ( + Agent, Block, BlocksAgents, FileMetadata, @@ -20,17 +21,14 @@ from letta.orm import ( SandboxConfig, SandboxEnvironmentVariable, Source, + SourcesAgents, Tool, ToolsAgents, User, ) from letta.orm.agents_tags import AgentsTags -from letta.orm.errors import ( - ForeignKeyConstraintViolationError, - NoResultFound, - UniqueConstraintViolationError, -) -from letta.schemas.agent import CreateAgent +from letta.orm.errors import NoResultFound, UniqueConstraintViolationError +from letta.schemas.agent import CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig @@ -40,7 +38,7 @@ from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage -from letta.schemas.message import MessageUpdate +from letta.schemas.message import MessageCreate, MessageUpdate from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.sandbox_config import ( @@ -56,17 +54,14 @@ from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolUpdate -from letta.services.block_manager import BlockManager -from letta.services.organization_manager import OrganizationManager -from letta.services.passage_manager import PassageManager -from letta.services.tool_manager import ToolManager -from letta.settings import tool_settings - -utils.DEBUG = True -from letta.config import LettaConfig +from letta.schemas.tool_rule import InitToolRule from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate from letta.server.server import SyncServer +from letta.services.block_manager import BlockManager +from letta.services.organization_manager import OrganizationManager +from letta.settings import tool_settings +from tests.helpers.utils import comprehensive_agent_checks DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( embedding_endpoint_type="hugging-face", @@ -91,6 +86,7 @@ def clear_tables(server: SyncServer): session.execute(delete(Job)) session.execute(delete(ToolsAgents)) # Clear ToolsAgents first session.execute(delete(BlocksAgents)) + session.execute(delete(SourcesAgents)) session.execute(delete(AgentsTags)) session.execute(delete(SandboxEnvironmentVariable)) session.execute(delete(SandboxConfig)) @@ -98,7 +94,7 @@ def clear_tables(server: SyncServer): session.execute(delete(FileMetadata)) session.execute(delete(Source)) session.execute(delete(Tool)) # Clear all records from the Tool table - session.execute(delete(AgentModel)) + session.execute(delete(Agent)) session.execute(delete(User)) # Clear all records from the user table session.execute(delete(Organization)) # Clear all records from the organization table session.commit() # Commit the deletion @@ -137,47 +133,27 @@ def default_source(server: SyncServer, default_user): yield source +@pytest.fixture +def other_source(server: SyncServer, default_user): + source_pydantic = PydanticSource( + name="Another Test Source", + description="This is yet another test source.", + metadata_={"type": "another_test"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + yield source + + @pytest.fixture def default_file(server: SyncServer, default_source, default_user, default_organization): file = server.source_manager.create_file( - PydanticFileMetadata( - file_name="test_file", organization_id=default_organization.id, source_id=default_source.id), + PydanticFileMetadata(file_name="test_file", organization_id=default_organization.id, source_id=default_source.id), actor=default_user, ) yield file -@pytest.fixture -def sarah_agent(server: SyncServer, default_user, default_organization): - """Fixture to create and return a sample agent within the default organization.""" - agent_state = server.create_agent( - request=CreateAgent( - name="sarah_agent", - # memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], - memory_blocks=[], - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - ), - actor=default_user, - ) - yield agent_state - - -@pytest.fixture -def charles_agent(server: SyncServer, default_user, default_organization): - """Fixture to create and return a sample agent within the default organization.""" - agent_state = server.create_agent( - request=CreateAgent( - name="charles_agent", - memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - ), - actor=default_user, - ) - yield agent_state - - @pytest.fixture def print_tool(server: SyncServer, default_user, default_organization): """Fixture to create a tool with default settings and clean up after the test.""" @@ -221,9 +197,9 @@ def hello_world_passage_fixture(server: SyncServer, default_user, default_file, organization_id=default_user.organization_id, agent_id=sarah_agent.id, file_id=default_file.id, - text="Hello, world!", - embedding=dummy_embedding, - embedding_config=DEFAULT_EMBEDDING_CONFIG + text="Hello, world!", + embedding=dummy_embedding, + embedding_config=DEFAULT_EMBEDDING_CONFIG, ) msg = server.passage_manager.create_passage(message, actor=default_user) @@ -239,14 +215,16 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a organization_id=default_user.organization_id, agent_id=sarah_agent.id, file_id=default_file.id, - text=f"Test passage {i}", - embedding=dummy_embedding, - embedding_config=DEFAULT_EMBEDDING_CONFIG - ) for i in range(4) + text=f"Test passage {i}", + embedding=dummy_embedding, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + for i in range(4) ] server.passage_manager.create_many_passages(passages, actor=default_user) return passages + @pytest.fixture def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent): """Fixture to create a tool with default settings and clean up after the test.""" @@ -346,6 +324,61 @@ def other_tool(server: SyncServer, default_user, default_organization): yield tool +@pytest.fixture +def sarah_agent(server: SyncServer, default_user, default_organization): + """Fixture to create and return a sample agent within the default organization.""" + agent_state = server.create_agent( + request=CreateAgent( + name="sarah_agent", + memory_blocks=[], + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) + yield agent_state + + +@pytest.fixture +def charles_agent(server: SyncServer, default_user, default_organization): + """Fixture to create and return a sample agent within the default organization.""" + agent_state = server.create_agent( + request=CreateAgent( + name="charles_agent", + memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) + yield agent_state + + +@pytest.fixture +def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_tool, default_source, default_block): + memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] + create_agent_request = CreateAgent( + system="test system", + memory_blocks=memory_blocks, + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + block_ids=[default_block.id], + tool_ids=[print_tool.id], + source_ids=[default_source.id], + tags=["a", "b"], + description="test_description", + metadata_={"test_key": "test_value"}, + tool_rules=[InitToolRule(tool_name=print_tool.name)], + initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")], + ) + created_agent = server.agent_manager.create_agent( + create_agent_request, + actor=default_user, + ) + + yield created_agent, create_agent_request + + @pytest.fixture(scope="module") def server(): config = LettaConfig.load() @@ -356,6 +389,389 @@ def server(): return server +# ====================================================================================================================== +# AgentManager Tests - Basic +# ====================================================================================================================== +def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixture, default_user): + # Test agent creation + created_agent, create_agent_request = comprehensive_test_agent_fixture + comprehensive_agent_checks(created_agent, create_agent_request) + + # Test get agent + get_agent = server.agent_manager.get_agent_by_id(agent_id=created_agent.id, actor=default_user) + comprehensive_agent_checks(get_agent, create_agent_request) + + # Test get agent name + get_agent_name = server.agent_manager.get_agent_by_name(agent_name=created_agent.name, actor=default_user) + comprehensive_agent_checks(get_agent_name, create_agent_request) + + # Test list agent + list_agents = server.agent_manager.list_agents(actor=default_user) + assert len(list_agents) == 1 + comprehensive_agent_checks(list_agents[0], create_agent_request) + + # Test deleting the agent + server.agent_manager.delete_agent(get_agent.id, default_user) + list_agents = server.agent_manager.list_agents(actor=default_user) + assert len(list_agents) == 0 + + +def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user): + agent, _ = comprehensive_test_agent_fixture + update_agent_request = UpdateAgent( + name="train_agent", + description="train description", + tool_ids=[other_tool.id], + source_ids=[other_source.id], + block_ids=[other_block.id], + tool_rules=[InitToolRule(tool_name=other_tool.name)], + tags=["c", "d"], + system="train system", + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(model_name="letta"), + message_ids=["10", "20"], + metadata_={"train_key": "train_value"}, + ) + + updated_agent = server.agent_manager.update_agent(agent.id, update_agent_request, actor=default_user) + comprehensive_agent_checks(updated_agent, update_agent_request) + assert updated_agent.message_ids == update_agent_request.message_ids + + +# ====================================================================================================================== +# AgentManager Tests - Sources Relationship +# ====================================================================================================================== + + +def test_attach_source(server: SyncServer, sarah_agent, default_source, default_user): + """Test attaching a source to an agent.""" + # Attach the source + server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + + # Verify attachment through get_agent_by_id + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert default_source.id in [s.id for s in agent.sources] + + # Verify that attaching the same source again doesn't cause issues + server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert len([s for s in agent.sources if s.id == default_source.id]) == 1 + + +def test_list_attached_source_ids(server: SyncServer, sarah_agent, default_source, other_source, default_user): + """Test listing source IDs attached to an agent.""" + # Initially should have no sources + sources = server.agent_manager.list_attached_sources(sarah_agent.id, actor=default_user) + assert len(sources) == 0 + + # Attach sources + server.agent_manager.attach_source(sarah_agent.id, default_source.id, actor=default_user) + server.agent_manager.attach_source(sarah_agent.id, other_source.id, actor=default_user) + + # List sources and verify + sources = server.agent_manager.list_attached_sources(sarah_agent.id, actor=default_user) + assert len(sources) == 2 + source_ids = [s.id for s in sources] + assert default_source.id in source_ids + assert other_source.id in source_ids + + +def test_detach_source(server: SyncServer, sarah_agent, default_source, default_user): + """Test detaching a source from an agent.""" + # Attach source + server.agent_manager.attach_source(sarah_agent.id, default_source.id, actor=default_user) + + # Verify it's attached + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert default_source.id in [s.id for s in agent.sources] + + # Detach source + server.agent_manager.detach_source(sarah_agent.id, default_source.id, actor=default_user) + + # Verify it's detached + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert default_source.id not in [s.id for s in agent.sources] + + # Verify that detaching an already detached source doesn't cause issues + server.agent_manager.detach_source(sarah_agent.id, default_source.id, actor=default_user) + + +def test_attach_source_nonexistent_agent(server: SyncServer, default_source, default_user): + """Test attaching a source to a nonexistent agent.""" + with pytest.raises(NoResultFound): + server.agent_manager.attach_source(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) + + +def test_attach_source_nonexistent_source(server: SyncServer, sarah_agent, default_user): + """Test attaching a nonexistent source to an agent.""" + with pytest.raises(NoResultFound): + server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id="nonexistent-source-id", actor=default_user) + + +def test_detach_source_nonexistent_agent(server: SyncServer, default_source, default_user): + """Test detaching a source from a nonexistent agent.""" + with pytest.raises(NoResultFound): + server.agent_manager.detach_source(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) + + +def test_list_attached_source_ids_nonexistent_agent(server: SyncServer, default_user): + """Test listing sources for a nonexistent agent.""" + with pytest.raises(NoResultFound): + server.agent_manager.list_attached_sources(agent_id="nonexistent-agent-id", actor=default_user) + + +def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, default_source, default_user): + """Test listing agents that have a particular source attached.""" + # Initially should have no attached agents + attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) + assert len(attached_agents) == 0 + + # Attach source to first agent + server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + + # Verify one agent is now attached + attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) + assert len(attached_agents) == 1 + assert sarah_agent.id in [a.id for a in attached_agents] + + # Attach source to second agent + server.agent_manager.attach_source(agent_id=charles_agent.id, source_id=default_source.id, actor=default_user) + + # Verify both agents are now attached + attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) + assert len(attached_agents) == 2 + attached_agent_ids = [a.id for a in attached_agents] + assert sarah_agent.id in attached_agent_ids + assert charles_agent.id in attached_agent_ids + + # Detach source from first agent + server.agent_manager.detach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + + # Verify only second agent remains attached + attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) + assert len(attached_agents) == 1 + assert charles_agent.id in [a.id for a in attached_agents] + + +def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user): + """Test listing agents for a nonexistent source.""" + with pytest.raises(NoResultFound): + server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user) + + +# ====================================================================================================================== +# AgentManager Tests - Tags Relationship +# ====================================================================================================================== + + +def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user): + """Test listing agents that have ALL specified tags.""" + # Create agents with multiple tags + server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["test", "production", "gpt4"]), actor=default_user) + server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["test", "development", "gpt4"]), actor=default_user) + + # Search for agents with all specified tags + agents = server.agent_manager.list_agents(tags=["test", "gpt4"], match_all_tags=True, actor=default_user) + assert len(agents) == 2 + agent_ids = [a.id for a in agents] + assert sarah_agent.id in agent_ids + assert charles_agent.id in agent_ids + + # Search for tags that only sarah_agent has + agents = server.agent_manager.list_agents(tags=["test", "production"], match_all_tags=True, actor=default_user) + assert len(agents) == 1 + assert agents[0].id == sarah_agent.id + + +def test_list_agents_by_tags_match_any(server: SyncServer, sarah_agent, charles_agent, default_user): + """Test listing agents that have ANY of the specified tags.""" + # Create agents with different tags + server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) + server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) + + # Search for agents with any of the specified tags + agents = server.agent_manager.list_agents(tags=["production", "development"], match_all_tags=False, actor=default_user) + assert len(agents) == 2 + agent_ids = [a.id for a in agents] + assert sarah_agent.id in agent_ids + assert charles_agent.id in agent_ids + + # Search for tags where only sarah_agent matches + agents = server.agent_manager.list_agents(tags=["production", "nonexistent"], match_all_tags=False, actor=default_user) + assert len(agents) == 1 + assert agents[0].id == sarah_agent.id + + +def test_list_agents_by_tags_no_matches(server: SyncServer, sarah_agent, charles_agent, default_user): + """Test listing agents when no tags match.""" + # Create agents with tags + server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) + server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) + + # Search for nonexistent tags + agents = server.agent_manager.list_agents(tags=["nonexistent1", "nonexistent2"], match_all_tags=True, actor=default_user) + assert len(agents) == 0 + + agents = server.agent_manager.list_agents(tags=["nonexistent1", "nonexistent2"], match_all_tags=False, actor=default_user) + assert len(agents) == 0 + + +def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent, charles_agent, default_user): + """Test combining tag search with other filters.""" + # Create agents with specific names and tags + server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(name="production_agent", tags=["production", "gpt4"]), actor=default_user) + server.agent_manager.update_agent(charles_agent.id, UpdateAgent(name="test_agent", tags=["production", "gpt3"]), actor=default_user) + + # List agents with specific tag and name pattern + agents = server.agent_manager.list_agents(actor=default_user, tags=["production"], match_all_tags=True, name="production_agent") + assert len(agents) == 1 + assert agents[0].id == sarah_agent.id + + +def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization): + """Test pagination when listing agents by tags.""" + # Create first agent + agent1 = server.create_agent( + request=CreateAgent( + name="agent1", + tags=["pagination_test", "tag1"], + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + memory_blocks=[], + ), + actor=default_user, + ) + + if USING_SQLITE: + time.sleep(CREATE_DELAY_SQLITE) # Ensure distinct created_at timestamps + + # Create second agent + agent2 = server.create_agent( + request=CreateAgent( + name="agent2", + tags=["pagination_test", "tag2"], + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + memory_blocks=[], + ), + actor=default_user, + ) + + # Get first page + first_page = server.agent_manager.list_agents(tags=["pagination_test"], match_all_tags=True, actor=default_user, limit=1) + assert len(first_page) == 1 + first_agent_id = first_page[0].id + + # Get second page using cursor + second_page = server.agent_manager.list_agents( + tags=["pagination_test"], match_all_tags=True, actor=default_user, cursor=first_agent_id, limit=1 + ) + assert len(second_page) == 1 + assert second_page[0].id != first_agent_id + + # Verify we got both agents with no duplicates + all_ids = {first_page[0].id, second_page[0].id} + assert len(all_ids) == 2 + assert agent1.id in all_ids + assert agent2.id in all_ids + + +# ====================================================================================================================== +# AgentManager Tests - Blocks Relationship +# ====================================================================================================================== + + +def test_attach_block(server: SyncServer, sarah_agent, default_block, default_user): + """Test attaching a block to an agent.""" + # Attach block + server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=default_block.id, actor=default_user) + + # Verify attachment + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert len(agent.memory.blocks) == 1 + assert agent.memory.blocks[0].id == default_block.id + assert agent.memory.blocks[0].label == default_block.label + + +def test_attach_block_duplicate_label(server: SyncServer, sarah_agent, default_block, other_block, default_user): + """Test attempting to attach a block with a duplicate label.""" + # Set up both blocks with same label + server.block_manager.update_block(default_block.id, BlockUpdate(label="same_label"), actor=default_user) + server.block_manager.update_block(other_block.id, BlockUpdate(label="same_label"), actor=default_user) + + # Attach first block + server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=default_block.id, actor=default_user) + + # Attempt to attach second block with same label + with pytest.raises(IntegrityError): + server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=other_block.id, actor=default_user) + + +def test_detach_block(server: SyncServer, sarah_agent, default_block, default_user): + """Test detaching a block by ID.""" + # Set up: attach block + server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=default_block.id, actor=default_user) + + # Detach block + server.agent_manager.detach_block(agent_id=sarah_agent.id, block_id=default_block.id, actor=default_user) + + # Verify detachment + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert len(agent.memory.blocks) == 0 + + +def test_detach_nonexistent_block(server: SyncServer, sarah_agent, default_user): + """Test detaching a block that isn't attached.""" + with pytest.raises(NoResultFound): + server.agent_manager.detach_block(agent_id=sarah_agent.id, block_id="nonexistent-block-id", actor=default_user) + + +def test_update_block_label(server: SyncServer, sarah_agent, default_block, default_user): + """Test updating a block's label updates the relationship.""" + # Attach block + server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=default_block.id, actor=default_user) + + # Update block label + new_label = "new_label" + server.block_manager.update_block(default_block.id, BlockUpdate(label=new_label), actor=default_user) + + # Verify relationship is updated + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + block = agent.memory.blocks[0] + assert block.id == default_block.id + assert block.label == new_label + + +def test_update_block_label_multiple_agents(server: SyncServer, sarah_agent, charles_agent, default_block, default_user): + """Test updating a block's label updates relationships for all agents.""" + # Attach block to both agents + server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=default_block.id, actor=default_user) + server.agent_manager.attach_block(agent_id=charles_agent.id, block_id=default_block.id, actor=default_user) + + # Update block label + new_label = "new_label" + server.block_manager.update_block(default_block.id, BlockUpdate(label=new_label), actor=default_user) + + # Verify both relationships are updated + for agent_id in [sarah_agent.id, charles_agent.id]: + agent = server.agent_manager.get_agent_by_id(agent_id, actor=default_user) + # Find our specific block by ID + block = next(b for b in agent.memory.blocks if b.id == default_block.id) + assert block.label == new_label + + +def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, default_user): + """Test retrieving a block by its label.""" + # Attach block + server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=default_block.id, actor=default_user) + + # Get block by label + block = server.agent_manager.get_block_with_label(agent_id=sarah_agent.id, block_label=default_block.label, actor=default_user) + + assert block.id == default_block.id + assert block.label == default_block.label + + # ====================================================================================================================== # Organization Manager Tests # ====================================================================================================================== @@ -407,6 +823,7 @@ def test_list_organizations_pagination(server: SyncServer): # Passage Manager Tests # ====================================================================================================================== + def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user): """Test creating a passage using hello_world_passage_fixture fixture""" assert hello_world_passage_fixture.id is not None @@ -489,10 +906,8 @@ def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, last_id_on_first_page = first_page[-1].id # Get second page - second_page = server.passage_manager.list_passages( - actor=default_user, cursor=last_id_on_first_page, limit=3 - ) - assert len(second_page) == 2 # Should have 2 remaining passages + second_page = server.passage_manager.list_passages(actor=default_user, cursor=last_id_on_first_page, limit=3) + assert len(second_page) == 2 # Should have 2 remaining passages assert all(r1.id != r2.id for r1 in first_page for r2 in second_page) @@ -505,16 +920,12 @@ def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixtu def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent): """Test searching passages by text content""" - search_results = server.passage_manager.list_passages( - agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10 - ) + search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10) assert len(search_results) == 4 assert all("Test passage" in msg.text for msg in search_results) - + # Test no results - search_results = server.passage_manager.list_passages( - agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10 - ) + search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10) assert len(search_results) == 0 @@ -522,18 +933,18 @@ def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_pa """Test filtering passages by date range with various scenarios""" # Set up test data with known dates base_time = datetime.utcnow() - + # Create passages at different times passages = [] time_offsets = [ - timedelta(days=-2), # 2 days ago - timedelta(days=-1), # Yesterday - timedelta(hours=-2), # 2 hours ago - timedelta(minutes=-30), # 30 minutes ago + timedelta(days=-2), # 2 days ago + timedelta(days=-1), # Yesterday + timedelta(hours=-2), # 2 hours ago + timedelta(minutes=-30), # 30 minutes ago timedelta(minutes=-1), # 1 minute ago - timedelta(minutes=0), # Now + timedelta(minutes=0), # Now ] - + for i, offset in enumerate(time_offsets): timestamp = base_time + offset passage = server.passage_manager.create_passage( @@ -544,9 +955,9 @@ def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_pa text=f"Test passage {i}", embedding=[0.1, 0.2, 0.3], embedding_config=DEFAULT_EMBEDDING_CONFIG, - created_at=timestamp + created_at=timestamp, ), - actor=default_user + actor=default_user, ) passages.append(passage) @@ -587,42 +998,31 @@ def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_pa "start_date": base_time - timedelta(seconds=30), "end_date": base_time + timedelta(seconds=30), "expected_count": 1 + 1, # date + "now" - } + }, ] # Run test cases for case in test_cases: results = server.passage_manager.list_passages( - agent_id=sarah_agent.id, - actor=default_user, - start_date=case["start_date"], - end_date=case["end_date"], - limit=10 + agent_id=sarah_agent.id, actor=default_user, start_date=case["start_date"], end_date=case["end_date"], limit=10 ) - + # Verify count - assert len(results) == case["expected_count"], \ - f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}" + assert ( + len(results) == case["expected_count"] + ), f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}" # Test edge cases - + # Test with start_date but no end_date results_start_only = server.passage_manager.list_passages( - agent_id=sarah_agent.id, - actor=default_user, - start_date=base_time - timedelta(minutes=2), - end_date=None, - limit=10 + agent_id=sarah_agent.id, actor=default_user, start_date=base_time - timedelta(minutes=2), end_date=None, limit=10 ) assert len(results_start_only) >= 2, "Should find passages after start_date" # Test with end_date but no start_date results_end_only = server.passage_manager.list_passages( - agent_id=sarah_agent.id, - actor=default_user, - start_date=None, - end_date=base_time - timedelta(days=1), - limit=10 + agent_id=sarah_agent.id, actor=default_user, start_date=None, end_date=base_time - timedelta(days=1), limit=10 ) assert len(results_end_only) >= 1, "Should find passages before end_date" @@ -632,7 +1032,7 @@ def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_pa actor=default_user, start_date=base_time - timedelta(days=3), end_date=base_time + timedelta(days=1), - limit=3 + limit=3, ) assert len(limited_results) <= 3, "Should respect the limit parameter" @@ -640,18 +1040,18 @@ def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_pa def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent): """Test vector search functionality for passages.""" passage_manager = server.passage_manager - embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG) - + embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG) + # Create passages with known embeddings passages = [] - + # Create passages with different embeddings test_passages = [ "I like red", "random text", "blue shoes", ] - + for text in test_passages: embedding = embed_model.get_text_embedding(text) passage = PydanticPassage( @@ -659,15 +1059,15 @@ def test_passage_vector_search(server: SyncServer, default_user, default_file, s organization_id=default_user.organization_id, agent_id=sarah_agent.id, embedding_config=DEFAULT_EMBEDDING_CONFIG, - embedding=embedding + embedding=embedding, ) created_passage = passage_manager.create_passage(passage, default_user) passages.append(created_passage) assert passage_manager.size(actor=default_user) == len(passages) - + # Query vector similar to "cats" embedding query_key = "What's my favorite color?" - + # List passages with vector search results = passage_manager.list_passages( actor=default_user, @@ -677,11 +1077,11 @@ def test_passage_vector_search(server: SyncServer, default_user, default_file, s embedding_config=DEFAULT_EMBEDDING_CONFIG, embed_query=True, ) - + # Verify results are ordered by similarity assert len(results) == 3 assert results[0].text == "I like red" - assert results[1].text == "random text" # For some reason the embedding model doesn't like "blue shoes" + assert results[1].text == "random text" # For some reason the embedding model doesn't like "blue shoes" assert results[2].text == "blue shoes" @@ -1164,7 +1564,7 @@ def test_delete_block(server: SyncServer, default_user): # ====================================================================================================================== -# Source Manager Tests - Sources +# SourceManager Tests - Sources # ====================================================================================================================== def test_create_source(server: SyncServer, default_user): """Test creating a new source.""" @@ -1376,86 +1776,6 @@ def test_delete_file(server: SyncServer, default_user, default_source): assert len(files) == 0 -# ====================================================================================================================== -# AgentsTagsManager Tests -# ====================================================================================================================== -def test_add_tag_to_agent(server: SyncServer, sarah_agent, default_user): - # Add a tag to the agent - tag_name = "test_tag" - tag_association = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) - - # Assert that the tag association was created correctly - assert tag_association.agent_id == sarah_agent.id - assert tag_association.tag == tag_name - - -def test_add_duplicate_tag_to_agent(server: SyncServer, sarah_agent, default_user): - # Add the same tag twice to the agent - tag_name = "test_tag" - first_tag = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) - duplicate_tag = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) - - # Assert that the second addition returns the existing tag without creating a duplicate - assert first_tag.agent_id == duplicate_tag.agent_id - assert first_tag.tag == duplicate_tag.tag - - # Get all the tags belonging to the agent - tags = server.agents_tags_manager.get_tags_for_agent(agent_id=sarah_agent.id, actor=default_user) - assert len(tags) == 1 - assert tags[0] == first_tag.tag - - -def test_delete_tag_from_agent(server: SyncServer, sarah_agent, default_user): - # Add a tag, then delete it - tag_name = "test_tag" - server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) - server.agents_tags_manager.delete_tag_from_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) - - # Assert the tag was deleted - agent_tags = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user) - assert sarah_agent.id not in agent_tags - - -def test_delete_nonexistent_tag_from_agent(server: SyncServer, sarah_agent, default_user): - # Attempt to delete a tag that doesn't exist - tag_name = "nonexistent_tag" - with pytest.raises(ValueError, match=f"Tag '{tag_name}' not found for agent '{sarah_agent.id}'"): - server.agents_tags_manager.delete_tag_from_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) - - -def test_delete_tag_from_nonexistent_agent(server: SyncServer, default_user): - # Attempt to delete a tag that doesn't exist - tag_name = "nonexistent_tag" - agent_id = "abc" - with pytest.raises(ValueError, match=f"Tag '{tag_name}' not found for agent '{agent_id}'"): - server.agents_tags_manager.delete_tag_from_agent(agent_id=agent_id, tag=tag_name, actor=default_user) - - -def test_get_agents_by_tag(server: SyncServer, sarah_agent, charles_agent, default_user, default_organization): - # Add a shared tag to multiple agents - tag_name = "shared_tag" - - # Add the same tag to both agents - server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) - server.agents_tags_manager.add_tag_to_agent(agent_id=charles_agent.id, tag=tag_name, actor=default_user) - - # Retrieve agents by tag - agent_ids = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user) - - # Assert that both agents are returned for the tag - assert sarah_agent.id in agent_ids - assert charles_agent.id in agent_ids - assert len(agent_ids) == 2 - - # Delete tags from only sarah agent - server.agents_tags_manager.delete_all_tags_from_agent(agent_id=sarah_agent.id, actor=default_user) - agent_ids = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user) - # Assert that both agents are returned for the tag - assert sarah_agent.id not in agent_ids - assert charles_agent.id in agent_ids - assert len(agent_ids) == 1 - - # ====================================================================================================================== # SandboxConfigManager Tests - Sandbox Configs # ====================================================================================================================== @@ -1605,205 +1925,6 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture, assert retrieved_env_var.key == sandbox_env_var_fixture.key -# ====================================================================================================================== -# BlocksAgentsManager Tests -# ====================================================================================================================== -def test_add_block_to_agent(server, sarah_agent, default_user, default_block): - block_association = server.blocks_agents_manager.add_block_to_agent( - agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label - ) - - assert block_association.agent_id == sarah_agent.id - assert block_association.block_id == default_block.id - assert block_association.block_label == default_block.label - - -def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agent, default_user, default_block): - # Add the block - block_association = server.blocks_agents_manager.add_block_to_agent( - agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label - ) - assert block_association.block_label == default_block.label - - # Change the block label - new_label = "banana" - block = server.block_manager.update_block(block_id=default_block.id, block_update=BlockUpdate(label=new_label), actor=default_user) - assert block.label == new_label - - # Get the association - labels = server.blocks_agents_manager.list_block_labels_for_agent(agent_id=sarah_agent.id) - assert new_label in labels - assert default_block.label not in labels - - -@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite") -def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user): - with pytest.raises(ForeignKeyConstraintViolationError): - server.blocks_agents_manager.add_block_to_agent( - agent_id=sarah_agent.id, block_id="nonexistent_block", block_label="nonexistent_label" - ) - - -def test_add_block_to_agent_duplicate_label(server, sarah_agent, default_user, default_block, other_block): - server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) - - with pytest.warns(UserWarning, match=f"Block label '{default_block.label}' already exists for agent '{sarah_agent.id}'"): - server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=default_block.label) - - -def test_remove_block_with_label_from_agent(server, sarah_agent, default_user, default_block): - server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) - - removed_block = server.blocks_agents_manager.remove_block_with_label_from_agent( - agent_id=sarah_agent.id, block_label=default_block.label - ) - - assert removed_block.block_label == default_block.label - assert removed_block.block_id == default_block.id - assert removed_block.agent_id == sarah_agent.id - - with pytest.raises(ValueError, match=f"Block label '{default_block.label}' not found for agent '{sarah_agent.id}'"): - server.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=sarah_agent.id, block_label=default_block.label) - - -def test_update_block_id_for_agent(server, sarah_agent, default_user, default_block, other_block): - server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) - - updated_block = server.blocks_agents_manager.update_block_id_for_agent( - agent_id=sarah_agent.id, block_label=default_block.label, new_block_id=other_block.id - ) - - assert updated_block.block_id == other_block.id - assert updated_block.block_label == default_block.label - assert updated_block.agent_id == sarah_agent.id - - -def test_list_block_ids_for_agent(server, sarah_agent, default_user, default_block, other_block): - server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) - server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=other_block.label) - - retrieved_block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=sarah_agent.id) - - assert set(retrieved_block_ids) == {default_block.id, other_block.id} - - -def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_user, default_block): - server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) - server.blocks_agents_manager.add_block_to_agent(agent_id=charles_agent.id, block_id=default_block.id, block_label=default_block.label) - - agent_ids = server.blocks_agents_manager.list_agent_ids_with_block(block_id=default_block.id) - - assert sarah_agent.id in agent_ids - assert charles_agent.id in agent_ids - assert len(agent_ids) == 2 - - -@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite") -def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block): - block_manager = BlockManager() - block_manager.delete_block(block_id=default_block.id, actor=default_user) - - with pytest.raises(ForeignKeyConstraintViolationError): - server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) - - -# ====================================================================================================================== -# ToolsAgentsManager Tests -# ====================================================================================================================== -def test_add_tool_to_agent(server, sarah_agent, default_user, print_tool): - tool_association = server.tools_agents_manager.add_tool_to_agent( - agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name - ) - - assert tool_association.agent_id == sarah_agent.id - assert tool_association.tool_id == print_tool.id - assert tool_association.tool_name == print_tool.name - - -def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent, default_user, print_tool): - # Add the tool - tool_association = server.tools_agents_manager.add_tool_to_agent( - agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name - ) - assert tool_association.tool_name == print_tool.name - - # Change the tool name - new_name = "banana" - tool = server.tool_manager.update_tool_by_id(tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user) - assert tool.name == new_name - - # Get the association - names = server.tools_agents_manager.list_tool_names_for_agent(agent_id=sarah_agent.id) - assert new_name in names - assert print_tool.name not in names - - -@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite") -def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user): - with pytest.raises(ForeignKeyConstraintViolationError): - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name") - - -def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool): - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) - - with pytest.warns(UserWarning, match=f"Tool name '{print_tool.name}' already exists for agent '{sarah_agent.id}'"): - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=print_tool.name) - - -def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool): - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) - - removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name) - - assert removed_tool.tool_name == print_tool.name - assert removed_tool.tool_id == print_tool.id - assert removed_tool.agent_id == sarah_agent.id - - with pytest.raises(ValueError, match=f"Tool name '{print_tool.name}' not found for agent '{sarah_agent.id}'"): - server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name) - - -def test_list_tool_ids_for_agent(server, sarah_agent, default_user, print_tool, other_tool): - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name) - - retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id) - - assert set(retrieved_tool_ids) == {print_tool.id, other_tool.id} - - -def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_user, print_tool): - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) - server.tools_agents_manager.add_tool_to_agent(agent_id=charles_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) - - agent_ids = server.tools_agents_manager.list_agent_ids_with_tool(tool_id=print_tool.id) - - assert sarah_agent.id in agent_ids - assert charles_agent.id in agent_ids - assert len(agent_ids) == 2 - - -@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite") -def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool): - tool_manager = ToolManager() - tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user) - - with pytest.raises(ForeignKeyConstraintViolationError): - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) - - -def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool): - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) - server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name) - - server.tools_agents_manager.remove_all_agent_tools(agent_id=sarah_agent.id) - - retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id) - - assert not retrieved_tool_ids - - # ====================================================================================================================== # JobManager Tests # ====================================================================================================================== diff --git a/tests/test_new_cli.py b/tests/test_new_cli.py deleted file mode 100644 index a41dbc29..00000000 --- a/tests/test_new_cli.py +++ /dev/null @@ -1,126 +0,0 @@ -# TODO: fix later - -# import os -# import random -# import string -# import unittest.mock -# -# import pytest -# -# from letta.cli.cli_config import add, delete, list -# from letta.config import LettaConfig -# from letta.credentials import LettaCredentials -# from tests.utils import create_config -# -# -# def _reset_config(): -# -# if os.getenv("OPENAI_API_KEY"): -# create_config("openai") -# credentials = LettaCredentials( -# openai_key=os.getenv("OPENAI_API_KEY"), -# ) -# else: # hosted -# create_config("letta_hosted") -# credentials = LettaCredentials() -# -# config = LettaConfig.load() -# config.save() -# credentials.save() -# print("_reset_config :: ", config.config_path) -# -# -# @pytest.mark.skip(reason="This is a helper function.") -# def generate_random_string(length): -# characters = string.ascii_letters + string.digits -# random_string = "".join(random.choices(characters, k=length)) -# return random_string -# -# -# @pytest.mark.skip(reason="Ensures LocalClient is used during testing.") -# def unset_env_variables(): -# server_url = os.environ.pop("MEMGPT_BASE_URL", None) -# token = os.environ.pop("MEMGPT_SERVER_PASS", None) -# return server_url, token -# -# -# @pytest.mark.skip(reason="Set env variables back to values before test.") -# def reset_env_variables(server_url, token): -# if server_url is not None: -# os.environ["MEMGPT_BASE_URL"] = server_url -# if token is not None: -# os.environ["MEMGPT_SERVER_PASS"] = token -# -# -# def test_crud_human(capsys): -# _reset_config() -# -# server_url, token = unset_env_variables() -# -# # Initialize values that won't interfere with existing ones -# human_1 = generate_random_string(16) -# text_1 = generate_random_string(32) -# human_2 = generate_random_string(16) -# text_2 = generate_random_string(32) -# text_3 = generate_random_string(32) -# -# # Add inital human -# add("human", human_1, text_1) -# -# # Expect inital human to be listed -# list("humans") -# captured = capsys.readouterr() -# output = captured.out[captured.out.find(human_1) :] -# -# assert human_1 in output -# assert text_1 in output -# -# # Add second human -# add("human", human_2, text_2) -# -# # Expect to see second human -# list("humans") -# captured = capsys.readouterr() -# output = captured.out[captured.out.find(human_1) :] -# -# assert human_1 in output -# assert text_1 in output -# assert human_2 in output -# assert text_2 in output -# -# with unittest.mock.patch("questionary.confirm") as mock_confirm: -# mock_confirm.return_value.ask.return_value = True -# -# # Update second human -# add("human", human_2, text_3) -# -# # Expect to see update text -# list("humans") -# captured = capsys.readouterr() -# output = captured.out[captured.out.find(human_1) :] -# -# assert human_1 in output -# assert text_1 in output -# assert human_2 in output -# assert output.count(human_2) == 1 -# assert text_3 in output -# assert text_2 not in output -# -# # Delete second human -# delete("human", human_2) -# -# # Expect second human to be deleted -# list("humans") -# captured = capsys.readouterr() -# output = captured.out[captured.out.find(human_1) :] -# -# assert human_1 in output -# assert text_1 in output -# assert human_2 not in output -# assert text_2 not in output -# -# # Clean up -# delete("human", human_1) -# -# reset_env_variables(server_url, token) -# diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py deleted file mode 100644 index 1fd3c6d4..00000000 --- a/tests/test_openai_client.py +++ /dev/null @@ -1,93 +0,0 @@ -from logging import getLogger - -from openai import APIConnectionError, OpenAI - -logger = getLogger(__name__) - - -def test_openai_assistant(): - client = OpenAI(base_url="http://127.0.0.1:8080/v1") - # create assistant - try: - assistant = client.beta.assistants.create( - name="Math Tutor", - instructions="You are a personal math tutor. Write and run code to answer math questions.", - # tools=[{"type": "code_interpreter"}], - model="gpt-4-turbo-preview", - ) - except APIConnectionError as e: - logger.error("Connection issue with localhost openai stub: %s", e) - return - # create thread - thread = client.beta.threads.create() - - message = client.beta.threads.messages.create( - thread_id=thread.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?" - ) - - run = client.beta.threads.runs.create( - thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe. The user has a premium account." - ) - - # run = client.beta.threads.runs.create( - # thread_id=thread.id, - # assistant_id=assistant.id, - # model="gpt-4-turbo-preview", - # instructions="New instructions that override the Assistant instructions", - # tools=[{"type": "code_interpreter"}, {"type": "retrieval"}] - # ) - - # Store the run ID - run_id = run.id - print(run_id) - - # NOTE: Letta does not support polling yet, so run status is always "completed" - # Retrieve all messages from the thread - messages = client.beta.threads.messages.list(thread_id=thread.id) - - # Print all messages from the thread - for msg in messages.messages: - role = msg["role"] - content = msg["content"][0] - print(f"{role.capitalize()}: {content}") - - # TODO: add once polling works - ## Polling for the run status - # while True: - # # Retrieve the run status - # run_status = client.beta.threads.runs.retrieve( - # thread_id=thread.id, - # run_id=run_id - # ) - - # # Check and print the step details - # run_steps = client.beta.threads.runs.steps.list( - # thread_id=thread.id, - # run_id=run_id - # ) - # for step in run_steps.data: - # if step.type == 'tool_calls': - # print(f"Tool {step.type} invoked.") - - # # If step involves code execution, print the code - # if step.type == 'code_interpreter': - # print(f"Python Code Executed: {step.step_details['code_interpreter']['input']}") - - # if run_status.status == 'completed': - # # Retrieve all messages from the thread - # messages = client.beta.threads.messages.list( - # thread_id=thread.id - # ) - - # # Print all messages from the thread - # for msg in messages.data: - # role = msg.role - # content = msg.content[0].text.value - # print(f"{role.capitalize()}: {content}") - # break # Exit the polling loop since the run is complete - # elif run_status.status in ['queued', 'in_progress']: - # print(f'{run_status.status.capitalize()}... Please wait.') - # time.sleep(1.5) # Wait before checking again - # else: - # print(f"Run status: {run_status.status}") - # break # Exit the polling loop if the status is neither 'in_progress' nor 'completed' diff --git a/tests/test_persistence.py b/tests/test_persistence.py deleted file mode 100644 index 9b86f2b2..00000000 --- a/tests/test_persistence.py +++ /dev/null @@ -1,52 +0,0 @@ -# test state saving between client session -# TODO: update this test with correct imports - - -# def test_save_load(client): -# """Test that state is being persisted correctly after an /exit -# -# Create a new agent, and request a message -# -# Then trigger -# """ -# assert client is not None, "Run create_agent test first" -# assert test_agent_state is not None, "Run create_agent test first" -# assert test_agent_state_post_message is not None, "Run test_user_message test first" -# -# # Create a new client (not thread safe), and load the same agent -# # The agent state inside should correspond to the initial state pre-message -# if os.getenv("OPENAI_API_KEY"): -# client2 = Letta(quickstart="openai", user_id=test_user_id) -# else: -# client2 = Letta(quickstart="letta_hosted", user_id=test_user_id) -# print(f"\n\n[3] CREATING CLIENT2, LOADING AGENT {test_agent_state.id}!") -# client2_agent_obj = client2.server.load_agent(user_id=test_user_id, agent_id=test_agent_state.id) -# client2_agent_state = client2_agent_obj.update_state() -# print(f"[3] LOADED AGENT! AGENT {client2_agent_state.id}\n\tmessages={client2_agent_state.state['messages']}") -# -# # assert test_agent_state == client2_agent_state, f"{vars(test_agent_state)}\n{vars(client2_agent_state)}" -# def check_state_equivalence(state_1, state_2): -# """Helper function that checks the equivalence of two AgentState objects""" -# assert state_1.keys() == state_2.keys(), f"{state_1.keys()}\n{state_2.keys}" -# for k, v1 in state_1.items(): -# v2 = state_2[k] -# if isinstance(v1, LLMConfig) or isinstance(v1, EmbeddingConfig): -# assert vars(v1) == vars(v2), f"{vars(v1)}\n{vars(v2)}" -# else: -# assert v1 == v2, f"{v1}\n{v2}" -# -# check_state_equivalence(vars(test_agent_state), vars(client2_agent_state)) -# -# # Now, write out the save from the original client -# # This should persist the test message into the agent state -# client.save() -# -# if os.getenv("OPENAI_API_KEY"): -# client3 = Letta(quickstart="openai", user_id=test_user_id) -# else: -# client3 = Letta(quickstart="letta_hosted", user_id=test_user_id) -# client3_agent_obj = client3.server.load_agent(user_id=test_user_id, agent_id=test_agent_state.id) -# client3_agent_state = client3_agent_obj.update_state() -# -# check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state)) -# diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py deleted file mode 100644 index d4eaec0c..00000000 --- a/tests/test_schema_generator.py +++ /dev/null @@ -1,62 +0,0 @@ -from letta.functions.schema_generator import generate_schema - - -def send_message(self, message: str): - """ - Sends a message to the human user. - - Args: - message (str): Message contents. All unicode (including emojis) are supported. - - Returns: - Optional[str]: None is always returned as this function does not produce a response. - """ - return None - - -def send_message_missing_types(self, message): - """ - Sends a message to the human user. - - Args: - message (str): Message contents. All unicode (including emojis) are supported. - - Returns: - Optional[str]: None is always returned as this function does not produce a response. - """ - return None - - -def send_message_missing_docstring(self, message: str): - return None - - -def test_schema_generator(): - # Check that a basic function schema converts correctly - correct_schema = { - "name": "send_message", - "description": "Sends a message to the human user.", - "parameters": { - "type": "object", - "properties": {"message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}}, - "required": ["message"], - }, - } - generated_schema = generate_schema(send_message) - print(f"\n\nreference_schema={correct_schema}") - print(f"\n\ngenerated_schema={generated_schema}") - assert correct_schema == generated_schema - - # Check that missing types results in an error - try: - _ = generate_schema(send_message_missing_types) - assert False - except: - pass - - # Check that missing docstring results in an error - try: - _ = generate_schema(send_message_missing_docstring) - assert False - except: - pass diff --git a/tests/test_server.py b/tests/test_server.py index 482fe894..09dfb94c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -19,8 +19,6 @@ from letta.schemas.letta_message import ( ) from letta.schemas.user import User -from .test_managers import DEFAULT_EMBEDDING_CONFIG - utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.agent import CreateAgent @@ -266,6 +264,7 @@ Lise, young Bolkónski's wife, this very evening, and perhaps the thing can be arranged. It shall be on your family's behalf that I'll start my apprenticeship as old maid.""" + @pytest.fixture(scope="module") def server(): config = LettaConfig.load() @@ -302,42 +301,66 @@ def user_id(server, org_id): @pytest.fixture(scope="module") -def agent_id(server, user_id): +def base_tools(server, user_id): + actor = server.user_manager.get_user_or_default(user_id) + tools = [] + for tool_name in BASE_TOOLS: + tools.append(server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)) + + yield tools + + +@pytest.fixture(scope="module") +def base_memory_tools(server, user_id): + actor = server.user_manager.get_user_or_default(user_id) + tools = [] + for tool_name in BASE_MEMORY_TOOLS: + tools.append(server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)) + + yield tools + + +@pytest.fixture(scope="module") +def agent_id(server, user_id, base_tools): # create agent + actor = server.user_manager.get_user_or_default(user_id) agent_state = server.create_agent( request=CreateAgent( name="test_agent", - tools=BASE_TOOLS, + tool_ids=[t.id for t in base_tools], memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), - actor=server.get_user_or_default(user_id), + actor=actor, ) print(f"Created agent\n{agent_state}") yield agent_state.id # cleanup - server.delete_agent(user_id, agent_state.id) + server.agent_manager.delete_agent(agent_state.id, actor=actor) + @pytest.fixture(scope="module") -def other_agent_id(server, user_id): +def other_agent_id(server, user_id, base_tools): # create agent + actor = server.user_manager.get_user_or_default(user_id) agent_state = server.create_agent( request=CreateAgent( name="test_agent_other", - tools=BASE_TOOLS, + tool_ids=[t.id for t in base_tools], memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), - actor=server.get_user_or_default(user_id), + actor=actor, ) print(f"Created agent\n{agent_state}") yield agent_state.id # cleanup - server.delete_agent(user_id, agent_state.id) + server.agent_manager.delete_agent(agent_state.id, actor=actor) + def test_error_on_nonexistent_agent(server, user_id, agent_id): try: @@ -416,6 +439,7 @@ def test_user_message(server, user_id, agent_id): @pytest.mark.order(5) def test_get_recall_memory(server, org_id, user_id, agent_id): # test recall memory cursor pagination + actor = server.user_manager.get_user_or_default(user_id=user_id) messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2) cursor1 = messages_1[-1].id messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000) @@ -427,7 +451,9 @@ def test_get_recall_memory(server, org_id, user_id, agent_id): assert len(messages_4) == 1 # test in-context message ids - in_context_ids = server.get_in_context_message_ids(agent_id=agent_id) + # in_context_ids = server.get_in_context_message_ids(agent_id=agent_id) + in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids + message_ids = [m.id for m in messages_3] for message_id in in_context_ids: assert message_id in message_ids, f"{message_id} not in {message_ids}" @@ -437,10 +463,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id): def test_get_archival_memory(server, user_id, agent_id): # test archival memory cursor pagination user = server.user_manager.get_user_by_id(user_id=user_id) - + # List latest 2 passages passages_1 = server.passage_manager.list_passages( - actor=user, agent_id=agent_id, ascending=False, limit=2, + actor=user, + agent_id=agent_id, + ascending=False, + limit=2, ) assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2" @@ -483,12 +512,13 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): - "rewrite" replaces the text of the last assistant message - "retry" retries the last assistant message """ + actor = server.user_manager.get_user_or_default(user_id) # Send an initial message server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") # Grab the raw Agent object - letta_agent = server.load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id, actor=actor) assert letta_agent._messages[-1].role == MessageRole.tool assert letta_agent._messages[-2].role == MessageRole.assistant last_agent_message = letta_agent._messages[-2] @@ -496,10 +526,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): # Try "rethink" new_thought = "I am thinking about the meaning of life, the universe, and everything. Bananas?" assert last_agent_message.text is not None and last_agent_message.text != new_thought - server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought) + server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought, actor=actor) # Grab the agent object again (make sure it's live) - letta_agent = server.load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id, actor=actor) assert letta_agent._messages[-1].role == MessageRole.tool assert letta_agent._messages[-2].role == MessageRole.assistant last_agent_message = letta_agent._messages[-2] @@ -513,10 +543,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): assert "message" in args_json and args_json["message"] is not None and args_json["message"] != "" new_text = "Why hello there my good friend! Is 42 what you're looking for? Bananas?" - server.rewrite_agent_message(agent_id=agent_id, new_text=new_text) + server.rewrite_agent_message(agent_id=agent_id, new_text=new_text, actor=actor) # Grab the agent object again (make sure it's live) - letta_agent = server.load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id, actor=actor) assert letta_agent._messages[-1].role == MessageRole.tool assert letta_agent._messages[-2].role == MessageRole.assistant last_agent_message = letta_agent._messages[-2] @@ -524,10 +554,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): assert "message" in args_json and args_json["message"] is not None and args_json["message"] == new_text # Try retry - server.retry_agent_message(agent_id=agent_id) + server.retry_agent_message(agent_id=agent_id, actor=actor) # Grab the agent object again (make sure it's live) - letta_agent = server.load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id, actor=actor) assert letta_agent._messages[-1].role == MessageRole.tool assert letta_agent._messages[-2].role == MessageRole.assistant last_agent_message = letta_agent._messages[-2] @@ -581,33 +611,6 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: ) -def test_load_agent_with_nonexistent_tool_names_does_not_error(server: SyncServer, user_id: str): - fake_tool_name = "blahblahblah" - tools = BASE_TOOLS + [fake_tool_name] - agent_state = server.create_agent( - request=CreateAgent( - name="nonexistent_tools_agent", - tools=tools, - memory_blocks=[], - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - ), - actor=server.get_user_or_default(user_id), - ) - - # Check that the tools in agent_state do NOT include the fake name - assert fake_tool_name not in agent_state.tool_names - assert set(BASE_TOOLS).issubset(set(agent_state.tool_names)) - - # Load the agent from the database and check that it doesn't error / tools are correct - saved_tools = server.get_tools_from_agent(agent_id=agent_state.id, user_id=user_id) - assert fake_tool_name not in agent_state.tool_names - assert set(BASE_TOOLS).issubset(set(agent_state.tool_names)) - - # cleanup - server.delete_agent(user_id, agent_state.id) - - def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str): agent_state = server.create_agent( request=CreateAgent( @@ -616,14 +619,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str): llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), - actor=server.get_user_or_default(user_id), + actor=server.user_manager.get_user_or_default(user_id), ) # create another user in the same org another_user = server.user_manager.create_user(User(organization_id=org_id, name="another")) # test that another user in the same org can delete the agent - server.delete_agent(another_user.id, agent_state.id) + server.agent_manager.delete_agent(agent_state.id, actor=another_user) def _test_get_messages_letta_format( @@ -887,14 +890,14 @@ def test_composio_client_simple(server): assert len(actions) > 0 -def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none): +def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools): """Test that the memory rebuild is generating the correct number of role=system messages""" - + actor = server.user_manager.get_user_or_default(user_id) # create agent agent_state = server.create_agent( request=CreateAgent( name="memory_rebuild_test_agent", - tools=BASE_TOOLS + BASE_MEMORY_TOOLS, + tool_ids=[t.id for t in base_tools + base_memory_tools], memory_blocks=[ CreateBlock(label="human", value="The human's name is Bob."), CreateBlock(label="persona", value="My name is Alice."), @@ -902,7 +905,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none): llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), - actor=server.get_user_or_default(user_id), + actor=actor, ) print(f"Created agent\n{agent_state}") @@ -929,31 +932,28 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none): try: # At this stage, there should only be 1 system message inside of recall storage num_system_messages, all_messages = count_system_messages_in_recall() - # assert num_system_messages == 1, (num_system_messages, all_messages) - assert num_system_messages == 2, (num_system_messages, all_messages) + assert num_system_messages == 1, (num_system_messages, all_messages) # Assuming core memory append actually ran correctly, at this point there should be 2 messages server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory") - # At this stage, there should only be 1 system message inside of recall storage + # At this stage, there should be 2 system message inside of recall storage num_system_messages, all_messages = count_system_messages_in_recall() - # assert num_system_messages == 2, (num_system_messages, all_messages) - assert num_system_messages == 3, (num_system_messages, all_messages) + assert num_system_messages == 2, (num_system_messages, all_messages) # Run server.load_agent, and make sure that the number of system messages is still 2 - server.load_agent(agent_id=agent_state.id) + server.load_agent(agent_id=agent_state.id, actor=actor) num_system_messages, all_messages = count_system_messages_in_recall() - # assert num_system_messages == 2, (num_system_messages, all_messages) - assert num_system_messages == 3, (num_system_messages, all_messages) + assert num_system_messages == 2, (num_system_messages, all_messages) finally: # cleanup - server.delete_agent(user_id, agent_state.id) + server.agent_manager.delete_agent(agent_state.id, actor=actor) def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path): - user = server.get_user_or_default(user_id) + actor = server.user_manager.get_user_or_default(user_id) # Create a source source = server.source_manager.create_source( @@ -962,7 +962,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot embedding_config=EmbeddingConfig.default_config(provider="openai"), created_by_id=user_id, ), - actor=user + actor=actor, ) # Create a test file with some content @@ -971,11 +971,10 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot test_file.write_text(test_content) # Attach source to agent first - agent = server.load_agent(agent_id=agent_id) - agent.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms) + server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor) # Get initial passage count - initial_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id) + initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id) assert initial_passage_count == 0 # Create a job for loading the first file @@ -984,7 +983,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot user_id=user_id, metadata_={"type": "embedding", "filename": test_file.name, "source_id": source.id}, ), - actor=user + actor=actor, ) # Load the first file to source @@ -992,17 +991,17 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot source_id=source.id, file_path=str(test_file), job_id=job.id, - actor=user, + actor=actor, ) # Verify job completed successfully - job = server.job_manager.get_job_by_id(job_id=job.id, actor=user) + job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor) assert job.status == "completed" - assert job.metadata_["num_passages"] == 1 + assert job.metadata_["num_passages"] == 1 assert job.metadata_["num_documents"] == 1 # Verify passages were added - first_file_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id) + first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id) assert first_file_passage_count > initial_passage_count # Create a second test file with different content @@ -1015,7 +1014,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot user_id=user_id, metadata_={"type": "embedding", "filename": test_file2.name, "source_id": source.id}, ), - actor=user + actor=actor, ) # Load the second file to source @@ -1023,22 +1022,22 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot source_id=source.id, file_path=str(test_file2), job_id=job2.id, - actor=user, + actor=actor, ) # Verify second job completed successfully - job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=user) + job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor) assert job2.status == "completed" assert job2.metadata_["num_passages"] >= 10 assert job2.metadata_["num_documents"] == 1 # Verify passages were appended (not replaced) - final_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id) + final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id) assert final_passage_count > first_file_passage_count # Verify both old and new content is searchable passages = server.passage_manager.list_passages( - actor=user, + actor=actor, agent_id=agent_id, source_id=source.id, query_text="what does Timber like to eat", diff --git a/tests/test_summarize.py b/tests/test_summarize.py index 4bf180e1..89968413 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -33,7 +33,7 @@ def create_test_agent(): ) global agent_obj - agent_obj = client.server.load_agent(agent_id=agent_state.id) + agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) def test_summarize_messages_inplace(mock_e2b_api_key_none): @@ -74,7 +74,7 @@ def test_summarize_messages_inplace(mock_e2b_api_key_none): print(f"test_summarize: response={response}") # reload agent object - agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id) + agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id, actor=client.user) agent_obj.summarize_messages_inplace() print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}") @@ -121,7 +121,7 @@ def test_auto_summarize(mock_e2b_api_key_none): # check if the summarize message is inside the messages assert isinstance(client, LocalClient), "Test only works with LocalClient" - agent_obj = client.server.load_agent(agent_id=agent_state.id) + agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) print("SUMMARY", summarize_message_exists(agent_obj._messages)) if summarize_message_exists(agent_obj._messages): break diff --git a/tests/test_v1_routes.py b/tests/test_v1_routes.py index 8f9d9972..d82bbc11 100644 --- a/tests/test_v1_routes.py +++ b/tests/test_v1_routes.py @@ -169,7 +169,7 @@ def configure_mock_sync_server(mock_sync_server): mock_sync_server.sandbox_config_manager.list_sandbox_env_vars_by_key.return_value = [mock_api_key] # Mock user retrieval - mock_sync_server.get_user_or_default.return_value = Mock() # Provide additional attributes if needed + mock_sync_server.user_manager.get_user_or_default.return_value = Mock() # Provide additional attributes if needed # ====================================================================================================================== @@ -182,7 +182,7 @@ def test_delete_tool(client, mock_sync_server, add_integers_tool): assert response.status_code == 200 mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with( - tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value + tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value ) @@ -195,7 +195,7 @@ def test_get_tool(client, mock_sync_server, add_integers_tool): assert response.json()["id"] == add_integers_tool.id assert response.json()["source_code"] == add_integers_tool.source_code mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with( - tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value + tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value ) @@ -216,7 +216,7 @@ def test_get_tool_id(client, mock_sync_server, add_integers_tool): assert response.status_code == 200 assert response.json() == add_integers_tool.id mock_sync_server.tool_manager.get_tool_by_name.assert_called_once_with( - tool_name=add_integers_tool.name, actor=mock_sync_server.get_user_or_default.return_value + tool_name=add_integers_tool.name, actor=mock_sync_server.user_manager.get_user_or_default.return_value ) @@ -268,7 +268,7 @@ def test_update_tool(client, mock_sync_server, update_integers_tool, add_integer assert response.status_code == 200 assert response.json()["id"] == add_integers_tool.id mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with( - tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.get_user_or_default.return_value + tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.user_manager.get_user_or_default.return_value ) @@ -280,7 +280,9 @@ def test_add_base_tools(client, mock_sync_server, add_integers_tool): assert response.status_code == 200 assert len(response.json()) == 1 assert response.json()[0]["id"] == add_integers_tool.id - mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(actor=mock_sync_server.get_user_or_default.return_value) + mock_sync_server.tool_manager.add_base_tools.assert_called_once_with( + actor=mock_sync_server.user_manager.get_user_or_default.return_value + ) def test_list_composio_apps(client, mock_sync_server, composio_apps): diff --git a/tests/test_vector_embeddings.py b/tests/test_vector_embeddings.py index 0ad25071..e65e6b9b 100644 --- a/tests/test_vector_embeddings.py +++ b/tests/test_vector_embeddings.py @@ -1,42 +1,39 @@ import numpy as np -import sqlite3 -import base64 -from numpy.testing import assert_array_almost_equal -import pytest +from letta.orm.sqlalchemy_base import adapt_array +from letta.orm.sqlite_functions import convert_array, verify_embedding_dimension -from letta.orm.sqlalchemy_base import adapt_array, convert_array -from letta.orm.sqlite_functions import verify_embedding_dimension def test_vector_conversions(): """Test the vector conversion functions""" # Create test data original = np.random.random(4096).astype(np.float32) print(f"Original shape: {original.shape}") - + # Test full conversion cycle encoded = adapt_array(original) print(f"Encoded type: {type(encoded)}") print(f"Encoded length: {len(encoded)}") - + decoded = convert_array(encoded) print(f"Decoded shape: {decoded.shape}") print(f"Dimension verification: {verify_embedding_dimension(decoded)}") - + # Verify data integrity np.testing.assert_array_almost_equal(original, decoded) print("✓ Data integrity verified") - + # Test with a list list_data = original.tolist() encoded_list = adapt_array(list_data) decoded_list = convert_array(encoded_list) np.testing.assert_array_almost_equal(original, decoded_list) print("✓ List conversion verified") - + # Test None handling assert adapt_array(None) is None assert convert_array(None) is None print("✓ None handling verified") -# Run the tests \ No newline at end of file + +# Run the tests