diff --git a/alembic/versions/f6cd5a1e519d_add_embedding_config_field_to_archives_.py b/alembic/versions/f6cd5a1e519d_add_embedding_config_field_to_archives_.py new file mode 100644 index 00000000..2cffe962 --- /dev/null +++ b/alembic/versions/f6cd5a1e519d_add_embedding_config_field_to_archives_.py @@ -0,0 +1,83 @@ +"""Add embedding config field to Archives table + +Revision ID: f6cd5a1e519d +Revises: c6c43222e2de +Create Date: 2025-10-23 16:33:53.661122 + +""" + +import json +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import text + +import letta.orm +from alembic import op +from letta.schemas.embedding_config import EmbeddingConfig + +# revision identifiers, used by Alembic. +revision: str = "f6cd5a1e519d" +down_revision: Union[str, None] = "c6c43222e2de" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # step 1: add column as nullable + op.add_column("archives", sa.Column("embedding_config", letta.orm.custom_columns.EmbeddingConfigColumn(), nullable=True)) + + # step 2: backfill existing archives with embedding configs in batches + connection = op.get_bind() + + # default embedding config for archives without passages + default_config = EmbeddingConfig.default_config(model_name="letta") + default_embedding_config = default_config.model_dump() + + batch_size = 100 + processed = 0 + + # process in batches until no more archives need backfilling + while True: + archives = connection.execute( + text("SELECT id FROM archives WHERE embedding_config IS NULL LIMIT :batch_size"), {"batch_size": batch_size} + ).fetchall() + + if not archives: + break + + for archive in archives: + archive_id = archive[0] + + # check if archive has passages + first_passage = connection.execute( + text("SELECT embedding_config FROM archival_passages WHERE archive_id = :archive_id AND is_deleted = FALSE LIMIT 1"), + {"archive_id": archive_id}, + ).fetchone() + + if first_passage and first_passage[0]: + embedding_config = first_passage[0] + else: + embedding_config = default_embedding_config + + # serialize the embedding config to JSON string for raw SQL + config_json = json.dumps(embedding_config) + + connection.execute( + text("UPDATE archives SET embedding_config = :config WHERE id = :archive_id"), + {"config": config_json, "archive_id": archive_id}, + ) + + processed += len(archives) + print(f"Backfilled {processed} archives so far...") + + connection.execute(text("COMMIT")) + + print(f"Backfill complete. Total archives processed: {processed}") + + # step 3: make column non-nullable + op.alter_column("archives", "embedding_config", nullable=False) + + +def downgrade() -> None: + op.drop_column("archives", "embedding_config") diff --git a/fern/openapi.json b/fern/openapi.json index 19cf5e71..56834bb6 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -18582,6 +18582,10 @@ "description": "The vector database provider used for this archive's passages", "default": "native" }, + "embedding_config": { + "$ref": "#/components/schemas/EmbeddingConfig", + "description": "Embedding configuration for passages in this archive" + }, "metadata": { "anyOf": [ { @@ -18605,7 +18609,12 @@ }, "additionalProperties": false, "type": "object", - "required": ["created_at", "name", "organization_id"], + "required": [ + "created_at", + "name", + "organization_id", + "embedding_config" + ], "title": "Archive", "description": "Representation of an archive - a collection of archival passages that can be shared between agents.\n\nParameters:\n id (str): The unique identifier of the archive.\n name (str): The name of the archive.\n description (str): A description of the archive.\n organization_id (str): The organization this archive belongs to.\n created_at (datetime): The creation date of the archive.\n metadata (dict): Additional metadata for the archive." }, @@ -18615,6 +18624,10 @@ "type": "string", "title": "Name" }, + "embedding_config": { + "$ref": "#/components/schemas/EmbeddingConfig", + "description": "Embedding configuration for the archive" + }, "description": { "anyOf": [ { @@ -18628,7 +18641,7 @@ } }, "type": "object", - "required": ["name"], + "required": ["name", "embedding_config"], "title": "ArchiveCreateRequest", "description": "Request model for creating an archive.\n\nIntentionally excludes vector_db_provider. These are derived internally (vector DB provider from env)." }, diff --git a/letta/orm/archive.py b/letta/orm/archive.py index 75f36906..16e0fddf 100644 --- a/letta/orm/archive.py +++ b/letta/orm/archive.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional from sqlalchemy import JSON, Enum, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship +from letta.orm.custom_columns import EmbeddingConfigColumn from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.archive import Archive as PydanticArchive @@ -45,6 +46,9 @@ class Archive(SqlalchemyBase, OrganizationMixin): default=VectorDBProvider.NATIVE, doc="The vector database provider used for this archive's passages", ) + embedding_config: Mapped[dict] = mapped_column( + EmbeddingConfigColumn, nullable=False, doc="Embedding configuration for passages in this archive" + ) metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive") _vector_db_namespace: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Private field for vector database namespace") diff --git a/letta/schemas/archive.py b/letta/schemas/archive.py index cd8e2ac0..e9a54a6c 100644 --- a/letta/schemas/archive.py +++ b/letta/schemas/archive.py @@ -3,6 +3,7 @@ from typing import Dict, Optional from pydantic import Field +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import PrimitiveType, VectorDBProvider from letta.schemas.letta_base import OrmMetadataBase @@ -16,6 +17,7 @@ class ArchiveBase(OrmMetadataBase): vector_db_provider: VectorDBProvider = Field( default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages" ) + embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for passages in this archive") metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata") diff --git a/letta/server/rest_api/routers/v1/archives.py b/letta/server/rest_api/routers/v1/archives.py index 3f3554c8..83bf85ab 100644 --- a/letta/server/rest_api/routers/v1/archives.py +++ b/letta/server/rest_api/routers/v1/archives.py @@ -1,11 +1,14 @@ +from datetime import datetime from typing import List, Literal, Optional from fastapi import APIRouter, Body, Depends, Query -from pydantic import BaseModel +from pydantic import BaseModel, Field from letta import AgentState from letta.schemas.agent import AgentRelationships from letta.schemas.archive import Archive as PydanticArchive, ArchiveBase +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.passage import Passage as PydanticPassage from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server from letta.server.server import SyncServer from letta.validators import AgentId, ArchiveId, PassageId @@ -20,6 +23,7 @@ class ArchiveCreateRequest(BaseModel): """ name: str + embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for the archive") description: Optional[str] = None @@ -45,6 +49,7 @@ async def create_archive( actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) return await server.archive_manager.create_archive_async( name=archive.name, + embedding_config=archive.embedding_config, description=archive.description, actor=actor, ) diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index d4c622f9..7f477d3f 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -10,6 +10,7 @@ from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.archive import Archive as PydanticArchive +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import PrimitiveType, VectorDBProvider from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry @@ -28,6 +29,7 @@ class ArchiveManager: async def create_archive_async( self, name: str, + embedding_config: EmbeddingConfig, description: Optional[str] = None, actor: PydanticUser = None, ) -> PydanticArchive: @@ -42,6 +44,7 @@ class ArchiveManager: description=description, organization_id=actor.organization_id, vector_db_provider=vector_db_provider, + embedding_config=embedding_config, ) await archive.create_async(session, actor=actor) return archive.to_pydantic() @@ -299,11 +302,9 @@ class ArchiveManager: @enforce_types @trace_method - @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_or_create_default_archive_for_agent_async( self, - agent_id: str, - agent_name: Optional[str] = None, + agent_state: PydanticAgentState, actor: PydanticUser = None, ) -> PydanticArchive: """Get the agent's default archive, creating one if it doesn't exist.""" @@ -315,14 +316,14 @@ class ArchiveManager: agent_manager = AgentManager() archive_ids = await agent_manager.get_agent_archive_ids_async( - agent_id=agent_id, + agent_id=agent_state.id, actor=actor, ) if archive_ids: # TODO: Remove this check once we support multiple archives per agent if len(archive_ids) > 1: - raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported") + raise ValueError(f"Agent {agent_state.id} has multiple archives, which is not yet supported") # Get the archive archive = await self.get_archive_by_id_async( archive_id=archive_ids[0], @@ -331,9 +332,10 @@ class ArchiveManager: return archive # Create a default archive for this agent - archive_name = f"{agent_name or f'Agent {agent_id}'}'s Archive" + archive_name = f"{agent_state.name}'s Archive" archive = await self.create_archive_async( name=archive_name, + embedding_config=agent_state.embedding_config, description="Default archive created automatically", actor=actor, ) @@ -341,7 +343,7 @@ class ArchiveManager: try: # Attach the agent to the archive as owner await self.attach_agent_to_archive_async( - agent_id=agent_id, + agent_id=agent_state.id, archive_id=archive.id, is_owner=True, actor=actor, @@ -350,12 +352,12 @@ class ArchiveManager: except IntegrityError: # race condition: another concurrent request already created and attached an archive # clean up the orphaned archive we just created - logger.info(f"Race condition detected for agent {agent_id}, cleaning up orphaned archive {archive.id}") + logger.info(f"Race condition detected for agent {agent_state.id}, cleaning up orphaned archive {archive.id}") await self.delete_archive_async(archive_id=archive.id, actor=actor) # fetch the existing archive that was created by the concurrent request archive_ids = await agent_manager.get_agent_archive_ids_async( - agent_id=agent_id, + agent_id=agent_state.id, actor=actor, ) if archive_ids: diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 5ddcff8a..c15b4e11 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -437,9 +437,7 @@ class PassageManager: ) # Get or create the default archive for the agent - archive = await self.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=agent_state.id, agent_name=agent_state.name, actor=actor - ) + archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor) text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size)) diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index f642e75e..efa8e515 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -189,7 +189,9 @@ def test_should_use_tpuf_with_settings(): @pytest.mark.asyncio async def test_archive_creation_with_tpuf_enabled(server, default_user, enable_turbopuffer): """Test that archives are created with correct vector_db_provider when TPUF is enabled""" - archive = await server.archive_manager.create_archive_async(name="Test Archive with TPUF", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="Test Archive with TPUF", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user + ) assert archive.vector_db_provider == VectorDBProvider.TPUF # TODO: Add cleanup when delete_archive method is available @@ -197,7 +199,9 @@ async def test_archive_creation_with_tpuf_enabled(server, default_user, enable_t @pytest.mark.asyncio async def test_archive_creation_with_tpuf_disabled(server, default_user, disable_turbopuffer): """Test that archives default to NATIVE when TPUF is disabled""" - archive = await server.archive_manager.create_archive_async(name="Test Archive without TPUF", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="Test Archive without TPUF", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user + ) assert archive.vector_db_provider == VectorDBProvider.NATIVE # TODO: Add cleanup when delete_archive method is available @@ -208,7 +212,9 @@ async def test_dual_write_and_query_with_real_tpuf(server, default_user, sarah_a """Test that passages are written to both SQL and Turbopuffer with real connection and can be queried""" # Create a TPUF-enabled archive - archive = await server.archive_manager.create_archive_async(name="Test TPUF Archive for Real Dual Write", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="Test TPUF Archive for Real Dual Write", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user + ) assert archive.vector_db_provider == VectorDBProvider.TPUF # Attach the agent to the archive @@ -351,9 +357,7 @@ async def test_native_only_operations(server, default_user, sarah_agent, disable """Test that operations work correctly when using only native PostgreSQL""" # Create archive (should be NATIVE since turbopuffer is disabled) - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) assert archive.vector_db_provider == VectorDBProvider.NATIVE # Insert passages - should only write to SQL @@ -1833,7 +1837,9 @@ async def test_message_date_filtering_with_real_tpuf(enable_message_embedding, d async def test_archive_namespace_tracking(server, default_user, enable_turbopuffer): """Test that archive namespaces are properly tracked in database""" # Create an archive - archive = await server.archive_manager.create_archive_async(name="Test Archive for Namespace", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="Test Archive for Namespace", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user + ) # Get namespace - should be generated and stored namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id) @@ -1854,7 +1860,9 @@ async def test_archive_namespace_tracking(server, default_user, enable_turbopuff async def test_namespace_consistency_with_tpuf_client(server, default_user, enable_turbopuffer): """Test that the namespace from managers matches what tpuf_client would generate""" # Create archive and agent - archive = await server.archive_manager.create_archive_async(name="Test Consistency Archive", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="Test Consistency Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user + ) # Get namespace from manager archive_namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id) @@ -1875,14 +1883,18 @@ async def test_environment_namespace_variation(server, default_user): try: settings.environment = None - archive = await server.archive_manager.create_archive_async(name="No Env Archive", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="No Env Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user + ) namespace_no_env = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id) assert namespace_no_env == f"archive_{archive.id}" # Test with environment settings.environment = "TESTING" - archive2 = await server.archive_manager.create_archive_async(name="With Env Archive", actor=default_user) + archive2 = await server.archive_manager.create_archive_async( + name="With Env Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user + ) namespace_with_env = await server.archive_manager.get_or_set_vector_db_namespace_async(archive2.id) assert namespace_with_env == f"archive_{archive2.id}_testing" diff --git a/tests/managers/conftest.py b/tests/managers/conftest.py index 6cc79435..27b5472f 100644 --- a/tests/managers/conftest.py +++ b/tests/managers/conftest.py @@ -395,7 +395,9 @@ async def comprehensive_test_agent_fixture(server: SyncServer, default_user, pri @pytest.fixture async def default_archive(server: SyncServer, default_user): """Create and return a default archive.""" - archive = await server.archive_manager.create_archive_async("test", actor=default_user) + archive = await server.archive_manager.create_archive_async( + "test", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user + ) yield archive @@ -403,9 +405,7 @@ async def default_archive(server: SyncServer, default_user): async def agent_passage_fixture(server: SyncServer, default_user, sarah_agent): """Create an agent passage.""" # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) passage = await server.passage_manager.create_agent_passage_async( PydanticPassage( diff --git a/tests/managers/test_archive_manager.py b/tests/managers/test_archive_manager.py index 3a44ec70..79e28365 100644 --- a/tests/managers/test_archive_manager.py +++ b/tests/managers/test_archive_manager.py @@ -54,7 +54,7 @@ from letta.orm import Base, Block from letta.orm.block_history import BlockHistory from letta.orm.errors import NoResultFound, UniqueConstraintViolationError from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel -from letta.schemas.agent import AgentRelationships, CreateAgent, UpdateAgent +from letta.schemas.agent import AgentRelationships, AgentState, CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ( @@ -109,7 +109,10 @@ from tests.utils import random_string async def test_archive_manager_delete_archive_async(server: SyncServer, default_user): """Test the delete_archive_async function.""" archive = await server.archive_manager.create_archive_async( - name="test_archive_to_delete", description="This archive will be deleted", actor=default_user + name="test_archive_to_delete", + description="This archive will be deleted", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + actor=default_user, ) retrieved_archive = await server.archive_manager.get_archive_by_id_async(archive_id=archive.id, actor=default_user) @@ -125,7 +128,10 @@ async def test_archive_manager_delete_archive_async(server: SyncServer, default_ async def test_archive_manager_get_agents_for_archive_async(server: SyncServer, default_user, sarah_agent): """Test getting all agents that have access to an archive.""" archive = await server.archive_manager.create_archive_async( - name="shared_archive", description="Archive shared by multiple agents", actor=default_user + name="shared_archive", + description="Archive shared by multiple agents", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + actor=default_user, ) agent2 = await server.agent_manager.create_agent_async( @@ -187,7 +193,10 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau # First, create an archive that will be attached by a "concurrent" request concurrent_archive = await server.archive_manager.create_archive_async( - name=f"{agent.name}'s Archive", description="Default archive created automatically", actor=default_user + name=f"{agent.name}'s Archive", + description="Default archive created automatically", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + actor=default_user, ) call_count = 0 @@ -206,9 +215,7 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau with patch.object(server.archive_manager, "create_archive_async", side_effect=track_create): with patch.object(server.archive_manager, "attach_agent_to_archive_async", side_effect=failing_attach): - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=agent.id, agent_name=agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent, actor=default_user) assert archive is not None assert archive.id == concurrent_archive.id # Should return the existing archive @@ -230,9 +237,7 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau @pytest.mark.asyncio async def test_archive_manager_get_agent_from_passage_async(server: SyncServer, default_user, sarah_agent): """Test getting the agent ID that owns a passage through its archive.""" - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) passage = await server.passage_manager.create_agent_passage_async( PydanticPassage( @@ -250,7 +255,7 @@ async def test_archive_manager_get_agent_from_passage_async(server: SyncServer, assert agent_id == sarah_agent.id orphan_archive = await server.archive_manager.create_archive_async( - name="orphan_archive", description="Archive with no agents", actor=default_user + name="orphan_archive", description="Archive with no agents", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user ) orphan_passage = await server.passage_manager.create_agent_passage_async( @@ -278,7 +283,7 @@ async def test_archive_manager_create_archive_async(server: SyncServer, default_ """Test creating a new archive with various parameters.""" # test creating with name and description archive = await server.archive_manager.create_archive_async( - name="test_archive_basic", description="Test archive description", actor=default_user + name="test_archive_basic", description="Test archive description", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user ) assert archive.name == "test_archive_basic" @@ -287,7 +292,9 @@ async def test_archive_manager_create_archive_async(server: SyncServer, default_ assert archive.id is not None # test creating without description - archive2 = await server.archive_manager.create_archive_async(name="test_archive_no_desc", actor=default_user) + archive2 = await server.archive_manager.create_archive_async( + name="test_archive_no_desc", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) assert archive2.name == "test_archive_no_desc" assert archive2.description is None @@ -303,7 +310,7 @@ async def test_archive_manager_get_archive_by_id_async(server: SyncServer, defau """Test retrieving an archive by its ID.""" # create an archive archive = await server.archive_manager.create_archive_async( - name="test_get_by_id", description="Archive to test get_by_id", actor=default_user + name="test_get_by_id", description="Archive to test get_by_id", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user ) # retrieve the archive @@ -327,7 +334,7 @@ async def test_archive_manager_update_archive_async(server: SyncServer, default_ """Test updating archive name and description.""" # create an archive archive = await server.archive_manager.create_archive_async( - name="original_name", description="original description", actor=default_user + name="original_name", description="original description", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user ) # update name only @@ -370,7 +377,7 @@ async def test_archive_manager_list_archives_async(server: SyncServer, default_u archives = [] for i in range(5): archive = await server.archive_manager.create_archive_async( - name=f"list_test_archive_{i}", description=f"Description {i}", actor=default_user + name=f"list_test_archive_{i}", description=f"Description {i}", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user ) archives.append(archive) @@ -413,8 +420,12 @@ async def test_archive_manager_list_archives_async(server: SyncServer, default_u async def test_archive_manager_attach_agent_to_archive_async(server: SyncServer, default_user, sarah_agent): """Test attaching agents to archives with ownership settings.""" # create archives - archive1 = await server.archive_manager.create_archive_async(name="archive_for_attachment_1", actor=default_user) - archive2 = await server.archive_manager.create_archive_async(name="archive_for_attachment_2", actor=default_user) + archive1 = await server.archive_manager.create_archive_async( + name="archive_for_attachment_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) + archive2 = await server.archive_manager.create_archive_async( + name="archive_for_attachment_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) # create another agent agent2 = await server.agent_manager.create_agent_async( @@ -466,7 +477,10 @@ async def test_archive_manager_detach_agent_from_archive_async(server: SyncServe """Test detaching agents from archives.""" # create archive and agents archive = await server.archive_manager.create_archive_async( - name="archive_for_detachment", description="Test archive for detachment", actor=default_user + name="archive_for_detachment", + description="Test archive for detachment", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + actor=default_user, ) agent1 = await server.agent_manager.create_agent_async( @@ -539,7 +553,9 @@ async def test_archive_manager_detach_agent_from_archive_async(server: SyncServe async def test_archive_manager_attach_detach_idempotency(server: SyncServer, default_user): """Test that attach and detach operations are idempotent.""" # create archive and agent - archive = await server.archive_manager.create_archive_async(name="idempotency_test_archive", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="idempotency_test_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) agent = await server.agent_manager.create_agent_async( agent_create=CreateAgent( @@ -598,8 +614,12 @@ async def test_archive_manager_attach_detach_idempotency(server: SyncServer, def async def test_archive_manager_detach_with_multiple_archives(server: SyncServer, default_user): """Test detaching an agent from one archive doesn't affect others.""" # create two archives - archive1 = await server.archive_manager.create_archive_async(name="multi_archive_1", actor=default_user) - archive2 = await server.archive_manager.create_archive_async(name="multi_archive_2", actor=default_user) + archive1 = await server.archive_manager.create_archive_async( + name="multi_archive_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) + archive2 = await server.archive_manager.create_archive_async( + name="multi_archive_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) # create two agents agent1 = await server.agent_manager.create_agent_async( @@ -663,7 +683,9 @@ async def test_archive_manager_detach_with_multiple_archives(server: SyncServer, async def test_archive_manager_detach_deleted_agent(server: SyncServer, default_user): """Test behavior when detaching a deleted agent.""" # create archive - archive = await server.archive_manager.create_archive_async(name="test_deleted_agent_archive", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="test_deleted_agent_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) # create and attach agent agent = await server.agent_manager.create_agent_async( @@ -701,7 +723,10 @@ async def test_archive_manager_cascade_delete_on_archive_deletion(server: SyncSe """Test that deleting an archive cascades to delete relationships in archives_agents table.""" # create archive archive = await server.archive_manager.create_archive_async( - name="archive_to_be_deleted", description="This archive will be deleted to test CASCADE", actor=default_user + name="archive_to_be_deleted", + description="This archive will be deleted to test CASCADE", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + actor=default_user, ) # create multiple agents and attach them to the archive @@ -777,7 +802,10 @@ async def test_archive_manager_list_agents_with_pagination(server: SyncServer, d """Test listing agents for an archive with pagination support.""" # create archive archive = await server.archive_manager.create_archive_async( - name="pagination_test_archive", description="Archive for testing pagination", actor=default_user + name="pagination_test_archive", + description="Archive for testing pagination", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + actor=default_user, ) # create multiple agents @@ -855,7 +883,9 @@ async def test_archive_manager_get_default_archive_for_agent_async(server: SyncS assert archive is None # create and attach an archive - created_archive = await server.archive_manager.create_archive_async(name="default_archive", actor=default_user) + created_archive = await server.archive_manager.create_archive_async( + name="default_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) await server.archive_manager.attach_agent_to_archive_async( agent_id=agent.id, archive_id=created_archive.id, is_owner=True, actor=default_user @@ -875,7 +905,9 @@ async def test_archive_manager_get_default_archive_for_agent_async(server: SyncS async def test_archive_manager_get_or_set_vector_db_namespace_async(server: SyncServer, default_user): """Test getting or setting vector database namespace for an archive.""" # create an archive - archive = await server.archive_manager.create_archive_async(name="test_vector_namespace", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="test_vector_namespace", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) # get/set namespace for the first time namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive_id=archive.id) @@ -897,7 +929,10 @@ async def test_archive_manager_get_agents_with_include_parameter(server: SyncSer """Test getting agents for an archive with include parameter to load relationships.""" # create an archive archive = await server.archive_manager.create_archive_async( - name="test_include_archive", description="Test archive for include parameter", actor=default_user + name="test_include_archive", + description="Test archive for include parameter", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + actor=default_user, ) # create agent without base tools (to avoid needing tools in test DB) @@ -961,7 +996,10 @@ async def test_archive_manager_delete_passage_from_archive_async(server: SyncSer """Test deleting a passage from an archive.""" # create archive archive = await server.archive_manager.create_archive_async( - name="test_passage_deletion_archive", description="Archive for testing passage deletion", actor=default_user + name="test_passage_deletion_archive", + description="Archive for testing passage deletion", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + actor=default_user, ) # create passages @@ -1015,8 +1053,12 @@ async def test_archive_manager_delete_passage_from_archive_async(server: SyncSer async def test_archive_manager_delete_passage_from_wrong_archive(server: SyncServer, default_user): """Test that deleting a passage from the wrong archive raises an error.""" # create two archives - archive1 = await server.archive_manager.create_archive_async(name="archive_1", actor=default_user) - archive2 = await server.archive_manager.create_archive_async(name="archive_2", actor=default_user) + archive1 = await server.archive_manager.create_archive_async( + name="archive_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) + archive2 = await server.archive_manager.create_archive_async( + name="archive_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) # create passage in archive1 passage = await server.passage_manager.create_agent_passage_async( @@ -1048,7 +1090,9 @@ async def test_archive_manager_delete_passage_from_wrong_archive(server: SyncSer async def test_archive_manager_delete_nonexistent_passage(server: SyncServer, default_user): """Test that deleting a non-existent passage raises an error.""" # create archive - archive = await server.archive_manager.create_archive_async(name="test_nonexistent_passage_archive", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="test_nonexistent_passage_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) # attempt to delete non-existent passage (use valid UUID4 format) fake_passage_id = f"passage-{uuid.uuid4()}" @@ -1065,7 +1109,9 @@ async def test_archive_manager_delete_nonexistent_passage(server: SyncServer, de async def test_archive_manager_delete_passage_from_nonexistent_archive(server: SyncServer, default_user): """Test that deleting a passage from a non-existent archive raises an error.""" # create archive and passage - archive = await server.archive_manager.create_archive_async(name="temp_archive", actor=default_user) + archive = await server.archive_manager.create_archive_async( + name="temp_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user + ) passage = await server.passage_manager.create_agent_passage_async( PydanticPassage( diff --git a/tests/managers/test_passage_manager.py b/tests/managers/test_passage_manager.py index 477db117..5adb69a4 100644 --- a/tests/managers/test_passage_manager.py +++ b/tests/managers/test_passage_manager.py @@ -54,7 +54,7 @@ from letta.orm import Base, Block from letta.orm.block_history import BlockHistory from letta.orm.errors import NoResultFound, UniqueConstraintViolationError from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel -from letta.schemas.agent import CreateAgent, UpdateAgent +from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ( @@ -245,9 +245,7 @@ async def test_agent_list_passages_vector_search( embed_model = mock_embed_model # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Create passages with known embeddings passages = [] @@ -429,9 +427,7 @@ async def test_passage_cascade_deletion( async def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_agent): """Test creating an agent passage using the new agent-specific method.""" # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) passage = await server.passage_manager.create_agent_passage_async( PydanticPassage( @@ -494,9 +490,7 @@ async def test_create_agent_passage_validation(server: SyncServer, default_user, ) # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Should fail if source_id is present with pytest.raises(ValueError, match="Agent passage cannot have source_id"): @@ -530,9 +524,7 @@ async def test_create_source_passage_validation(server: SyncServer, default_user ) # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Should fail if archive_id is present with pytest.raises(ValueError, match="Source passage cannot have archive_id"): @@ -554,9 +546,7 @@ async def test_create_source_passage_validation(server: SyncServer, default_user async def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sarah_agent): """Test retrieving an agent passage using the new agent-specific method.""" # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Create an agent passage passage = await server.passage_manager.create_agent_passage_async( @@ -608,9 +598,7 @@ async def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sa """Test that trying to get the wrong passage type with specific methods fails.""" # Create an agent passage # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) agent_passage = await server.passage_manager.create_agent_passage_async( PydanticPassage( @@ -650,9 +638,7 @@ async def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sa async def test_update_agent_passage_specific(server: SyncServer, default_user, sarah_agent): """Test updating an agent passage using the new agent-specific method.""" # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Create an agent passage passage = await server.passage_manager.create_agent_passage_async( @@ -724,9 +710,7 @@ async def test_update_source_passage_specific(server: SyncServer, default_user, async def test_delete_agent_passage_specific(server: SyncServer, default_user, sarah_agent): """Test deleting an agent passage using the new agent-specific method.""" # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Create an agent passage passage = await server.passage_manager.create_agent_passage_async( @@ -787,9 +771,7 @@ async def test_delete_source_passage_specific(server: SyncServer, default_user, async def test_create_many_agent_passages_async(server: SyncServer, default_user, sarah_agent): """Test creating multiple agent passages using the new batch method.""" # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) passages = [ PydanticPassage( @@ -846,9 +828,7 @@ async def test_agent_passage_size(server: SyncServer, default_user, sarah_agent) initial_size = await server.passage_manager.agent_passage_size_async(actor=default_user, agent_id=sarah_agent.id) # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Create some agent passages for i in range(3): @@ -873,9 +853,7 @@ async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServe from letta.schemas.enums import TagMatchMode # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Create passages with different tag combinations test_passages = [ @@ -969,8 +947,7 @@ async def test_comprehensive_tag_functionality(disable_turbopuffer, server: Sync # Test 2: Verify unique tags for archive archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, - agent_name=sarah_agent.name, + agent_state=sarah_agent, actor=default_user, ) @@ -1194,8 +1171,7 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age # Verify unique tags includes all special character tags archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, - agent_name=sarah_agent.name, + agent_state=sarah_agent, actor=default_user, ) @@ -1239,9 +1215,7 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age async def test_search_agent_archival_memory_async(disable_turbopuffer, server: SyncServer, default_user, sarah_agent): """Test the search_agent_archival_memory_async method that powers both the agent tool and API endpoint.""" # Get or create default archive for the agent - archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user - ) + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Create test passages with various content and tags test_data = [