feat: Support embedding config on the archive [LET-5832] (#5714)

* Add embedding config field to archives

* Fix alembic script

* Simplify archive manager

* Fern autogen

* Fix failing tests

* Fix alembic
This commit is contained in:
Matthew Zhou
2025-10-23 17:37:58 -07:00
committed by Caren Thomas
parent c7c0d7507c
commit e7e86124f9
11 changed files with 242 additions and 103 deletions

View File

@@ -0,0 +1,83 @@
"""Add embedding config field to Archives table
Revision ID: f6cd5a1e519d
Revises: c6c43222e2de
Create Date: 2025-10-23 16:33:53.661122
"""
import json
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy import text
import letta.orm
from alembic import op
from letta.schemas.embedding_config import EmbeddingConfig
# revision identifiers, used by Alembic.
revision: str = "f6cd5a1e519d"
down_revision: Union[str, None] = "c6c43222e2de"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add a NOT NULL ``embedding_config`` column to ``archives``.

    Runs in three steps so the migration is safe on tables with existing rows:
      1. add the column as nullable,
      2. backfill every existing archive in batches of 100,
      3. tighten the column to NOT NULL once no row is left NULL.

    Backfill policy: an archive inherits the embedding config of its first
    non-deleted passage; archives with no passages get the default "letta"
    config.
    """
    # step 1: add column as nullable
    op.add_column("archives", sa.Column("embedding_config", letta.orm.custom_columns.EmbeddingConfigColumn(), nullable=True))
    # step 2: backfill existing archives with embedding configs in batches
    connection = op.get_bind()
    # default embedding config for archives without passages
    default_config = EmbeddingConfig.default_config(model_name="letta")
    default_embedding_config = default_config.model_dump()
    batch_size = 100
    processed = 0
    # process in batches until no more archives need backfilling; each UPDATE
    # below sets embedding_config, so already-processed rows drop out of the
    # next SELECT and the loop terminates once every row is backfilled
    while True:
        archives = connection.execute(
            text("SELECT id FROM archives WHERE embedding_config IS NULL LIMIT :batch_size"), {"batch_size": batch_size}
        ).fetchall()
        if not archives:
            break
        for archive in archives:
            archive_id = archive[0]
            # check if archive has passages; any non-deleted passage's config
            # is representative since all of an archive's passages share one
            first_passage = connection.execute(
                text("SELECT embedding_config FROM archival_passages WHERE archive_id = :archive_id AND is_deleted = FALSE LIMIT 1"),
                {"archive_id": archive_id},
            ).fetchone()
            if first_passage and first_passage[0]:
                # NOTE(review): raw SQL bypasses EmbeddingConfigColumn, so the
                # fetched value's type is driver-dependent (dict for JSON/JSONB
                # on psycopg, str for TEXT-backed storage). json.dumps on an
                # already-serialized str would double-encode — presumably the
                # column deserializes to a dict here; verify per backend.
                embedding_config = first_passage[0]
            else:
                embedding_config = default_embedding_config
            # serialize the embedding config to JSON string for raw SQL
            config_json = json.dumps(embedding_config)
            connection.execute(
                text("UPDATE archives SET embedding_config = :config WHERE id = :archive_id"),
                {"config": config_json, "archive_id": archive_id},
            )
        processed += len(archives)
        print(f"Backfilled {processed} archives so far...")
    # NOTE(review): explicit COMMIT persists the backfill before the DDL below;
    # this ends Alembic's surrounding transaction on transactional-DDL backends
    # (e.g. PostgreSQL), so the alter_column runs outside it — confirm intended.
    connection.execute(text("COMMIT"))
    print(f"Backfill complete. Total archives processed: {processed}")
    # step 3: make column non-nullable (safe: every row was just backfilled)
    op.alter_column("archives", "embedding_config", nullable=False)
def downgrade() -> None:
    """Reverse this revision by dropping ``archives.embedding_config``.

    No data is preserved: any backfilled or user-supplied embedding config
    stored in the column is discarded along with it.
    """
    table_name = "archives"
    column_name = "embedding_config"
    op.drop_column(table_name, column_name)

View File

@@ -18582,6 +18582,10 @@
"description": "The vector database provider used for this archive's passages",
"default": "native"
},
"embedding_config": {
"$ref": "#/components/schemas/EmbeddingConfig",
"description": "Embedding configuration for passages in this archive"
},
"metadata": {
"anyOf": [
{
@@ -18605,7 +18609,12 @@
},
"additionalProperties": false,
"type": "object",
"required": ["created_at", "name", "organization_id"],
"required": [
"created_at",
"name",
"organization_id",
"embedding_config"
],
"title": "Archive",
"description": "Representation of an archive - a collection of archival passages that can be shared between agents.\n\nParameters:\n id (str): The unique identifier of the archive.\n name (str): The name of the archive.\n description (str): A description of the archive.\n organization_id (str): The organization this archive belongs to.\n created_at (datetime): The creation date of the archive.\n metadata (dict): Additional metadata for the archive."
},
@@ -18615,6 +18624,10 @@
"type": "string",
"title": "Name"
},
"embedding_config": {
"$ref": "#/components/schemas/EmbeddingConfig",
"description": "Embedding configuration for the archive"
},
"description": {
"anyOf": [
{
@@ -18628,7 +18641,7 @@
}
},
"type": "object",
"required": ["name"],
"required": ["name", "embedding_config"],
"title": "ArchiveCreateRequest",
"description": "Request model for creating an archive.\n\nIntentionally excludes vector_db_provider. These are derived internally (vector DB provider from env)."
},

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, Enum, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import EmbeddingConfigColumn
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.archive import Archive as PydanticArchive
@@ -45,6 +46,9 @@ class Archive(SqlalchemyBase, OrganizationMixin):
default=VectorDBProvider.NATIVE,
doc="The vector database provider used for this archive's passages",
)
embedding_config: Mapped[dict] = mapped_column(
EmbeddingConfigColumn, nullable=False, doc="Embedding configuration for passages in this archive"
)
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive")
_vector_db_namespace: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Private field for vector database namespace")

View File

@@ -3,6 +3,7 @@ from typing import Dict, Optional
from pydantic import Field
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType, VectorDBProvider
from letta.schemas.letta_base import OrmMetadataBase
@@ -16,6 +17,7 @@ class ArchiveBase(OrmMetadataBase):
vector_db_provider: VectorDBProvider = Field(
default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages"
)
embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for passages in this archive")
metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata")

View File

@@ -1,11 +1,14 @@
from datetime import datetime
from typing import List, Literal, Optional
from fastapi import APIRouter, Body, Depends, Query
from pydantic import BaseModel
from pydantic import BaseModel, Field
from letta import AgentState
from letta.schemas.agent import AgentRelationships
from letta.schemas.archive import Archive as PydanticArchive, ArchiveBase
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.passage import Passage as PydanticPassage
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.server import SyncServer
from letta.validators import AgentId, ArchiveId, PassageId
@@ -20,6 +23,7 @@ class ArchiveCreateRequest(BaseModel):
"""
name: str
embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for the archive")
description: Optional[str] = None
@@ -45,6 +49,7 @@ async def create_archive(
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.archive_manager.create_archive_async(
name=archive.name,
embedding_config=archive.embedding_config,
description=archive.description,
actor=actor,
)

View File

@@ -10,6 +10,7 @@ from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.archive import Archive as PydanticArchive
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType, VectorDBProvider
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
@@ -28,6 +29,7 @@ class ArchiveManager:
async def create_archive_async(
self,
name: str,
embedding_config: EmbeddingConfig,
description: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticArchive:
@@ -42,6 +44,7 @@ class ArchiveManager:
description=description,
organization_id=actor.organization_id,
vector_db_provider=vector_db_provider,
embedding_config=embedding_config,
)
await archive.create_async(session, actor=actor)
return archive.to_pydantic()
@@ -299,11 +302,9 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_or_create_default_archive_for_agent_async(
self,
agent_id: str,
agent_name: Optional[str] = None,
agent_state: PydanticAgentState,
actor: PydanticUser = None,
) -> PydanticArchive:
"""Get the agent's default archive, creating one if it doesn't exist."""
@@ -315,14 +316,14 @@ class ArchiveManager:
agent_manager = AgentManager()
archive_ids = await agent_manager.get_agent_archive_ids_async(
agent_id=agent_id,
agent_id=agent_state.id,
actor=actor,
)
if archive_ids:
# TODO: Remove this check once we support multiple archives per agent
if len(archive_ids) > 1:
raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported")
raise ValueError(f"Agent {agent_state.id} has multiple archives, which is not yet supported")
# Get the archive
archive = await self.get_archive_by_id_async(
archive_id=archive_ids[0],
@@ -331,9 +332,10 @@ class ArchiveManager:
return archive
# Create a default archive for this agent
archive_name = f"{agent_name or f'Agent {agent_id}'}'s Archive"
archive_name = f"{agent_state.name}'s Archive"
archive = await self.create_archive_async(
name=archive_name,
embedding_config=agent_state.embedding_config,
description="Default archive created automatically",
actor=actor,
)
@@ -341,7 +343,7 @@ class ArchiveManager:
try:
# Attach the agent to the archive as owner
await self.attach_agent_to_archive_async(
agent_id=agent_id,
agent_id=agent_state.id,
archive_id=archive.id,
is_owner=True,
actor=actor,
@@ -350,12 +352,12 @@ class ArchiveManager:
except IntegrityError:
# race condition: another concurrent request already created and attached an archive
# clean up the orphaned archive we just created
logger.info(f"Race condition detected for agent {agent_id}, cleaning up orphaned archive {archive.id}")
logger.info(f"Race condition detected for agent {agent_state.id}, cleaning up orphaned archive {archive.id}")
await self.delete_archive_async(archive_id=archive.id, actor=actor)
# fetch the existing archive that was created by the concurrent request
archive_ids = await agent_manager.get_agent_archive_ids_async(
agent_id=agent_id,
agent_id=agent_state.id,
actor=actor,
)
if archive_ids:

View File

@@ -437,9 +437,7 @@ class PassageManager:
)
# Get or create the default archive for the agent
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=agent_state.id, agent_name=agent_state.name, actor=actor
)
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor)
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))

View File

@@ -189,7 +189,9 @@ def test_should_use_tpuf_with_settings():
@pytest.mark.asyncio
async def test_archive_creation_with_tpuf_enabled(server, default_user, enable_turbopuffer):
"""Test that archives are created with correct vector_db_provider when TPUF is enabled"""
archive = await server.archive_manager.create_archive_async(name="Test Archive with TPUF", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="Test Archive with TPUF", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
)
assert archive.vector_db_provider == VectorDBProvider.TPUF
# TODO: Add cleanup when delete_archive method is available
@@ -197,7 +199,9 @@ async def test_archive_creation_with_tpuf_enabled(server, default_user, enable_t
@pytest.mark.asyncio
async def test_archive_creation_with_tpuf_disabled(server, default_user, disable_turbopuffer):
"""Test that archives default to NATIVE when TPUF is disabled"""
archive = await server.archive_manager.create_archive_async(name="Test Archive without TPUF", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="Test Archive without TPUF", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
)
assert archive.vector_db_provider == VectorDBProvider.NATIVE
# TODO: Add cleanup when delete_archive method is available
@@ -208,7 +212,9 @@ async def test_dual_write_and_query_with_real_tpuf(server, default_user, sarah_a
"""Test that passages are written to both SQL and Turbopuffer with real connection and can be queried"""
# Create a TPUF-enabled archive
archive = await server.archive_manager.create_archive_async(name="Test TPUF Archive for Real Dual Write", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="Test TPUF Archive for Real Dual Write", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
)
assert archive.vector_db_provider == VectorDBProvider.TPUF
# Attach the agent to the archive
@@ -351,9 +357,7 @@ async def test_native_only_operations(server, default_user, sarah_agent, disable
"""Test that operations work correctly when using only native PostgreSQL"""
# Create archive (should be NATIVE since turbopuffer is disabled)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
assert archive.vector_db_provider == VectorDBProvider.NATIVE
# Insert passages - should only write to SQL
@@ -1833,7 +1837,9 @@ async def test_message_date_filtering_with_real_tpuf(enable_message_embedding, d
async def test_archive_namespace_tracking(server, default_user, enable_turbopuffer):
"""Test that archive namespaces are properly tracked in database"""
# Create an archive
archive = await server.archive_manager.create_archive_async(name="Test Archive for Namespace", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="Test Archive for Namespace", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
)
# Get namespace - should be generated and stored
namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id)
@@ -1854,7 +1860,9 @@ async def test_archive_namespace_tracking(server, default_user, enable_turbopuff
async def test_namespace_consistency_with_tpuf_client(server, default_user, enable_turbopuffer):
"""Test that the namespace from managers matches what tpuf_client would generate"""
# Create archive and agent
archive = await server.archive_manager.create_archive_async(name="Test Consistency Archive", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="Test Consistency Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
)
# Get namespace from manager
archive_namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id)
@@ -1875,14 +1883,18 @@ async def test_environment_namespace_variation(server, default_user):
try:
settings.environment = None
archive = await server.archive_manager.create_archive_async(name="No Env Archive", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="No Env Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
)
namespace_no_env = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id)
assert namespace_no_env == f"archive_{archive.id}"
# Test with environment
settings.environment = "TESTING"
archive2 = await server.archive_manager.create_archive_async(name="With Env Archive", actor=default_user)
archive2 = await server.archive_manager.create_archive_async(
name="With Env Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
)
namespace_with_env = await server.archive_manager.get_or_set_vector_db_namespace_async(archive2.id)
assert namespace_with_env == f"archive_{archive2.id}_testing"

View File

@@ -395,7 +395,9 @@ async def comprehensive_test_agent_fixture(server: SyncServer, default_user, pri
@pytest.fixture
async def default_archive(server: SyncServer, default_user):
"""Create and return a default archive."""
archive = await server.archive_manager.create_archive_async("test", actor=default_user)
archive = await server.archive_manager.create_archive_async(
"test", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
)
yield archive
@@ -403,9 +405,7 @@ async def default_archive(server: SyncServer, default_user):
async def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
"""Create an agent passage."""
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
passage = await server.passage_manager.create_agent_passage_async(
PydanticPassage(

View File

@@ -54,7 +54,7 @@ from letta.orm import Base, Block
from letta.orm.block_history import BlockHistory
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
from letta.schemas.agent import AgentRelationships, CreateAgent, UpdateAgent
from letta.schemas.agent import AgentRelationships, AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import (
@@ -109,7 +109,10 @@ from tests.utils import random_string
async def test_archive_manager_delete_archive_async(server: SyncServer, default_user):
"""Test the delete_archive_async function."""
archive = await server.archive_manager.create_archive_async(
name="test_archive_to_delete", description="This archive will be deleted", actor=default_user
name="test_archive_to_delete",
description="This archive will be deleted",
embedding_config=DEFAULT_EMBEDDING_CONFIG,
actor=default_user,
)
retrieved_archive = await server.archive_manager.get_archive_by_id_async(archive_id=archive.id, actor=default_user)
@@ -125,7 +128,10 @@ async def test_archive_manager_delete_archive_async(server: SyncServer, default_
async def test_archive_manager_get_agents_for_archive_async(server: SyncServer, default_user, sarah_agent):
"""Test getting all agents that have access to an archive."""
archive = await server.archive_manager.create_archive_async(
name="shared_archive", description="Archive shared by multiple agents", actor=default_user
name="shared_archive",
description="Archive shared by multiple agents",
embedding_config=DEFAULT_EMBEDDING_CONFIG,
actor=default_user,
)
agent2 = await server.agent_manager.create_agent_async(
@@ -187,7 +193,10 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau
# First, create an archive that will be attached by a "concurrent" request
concurrent_archive = await server.archive_manager.create_archive_async(
name=f"{agent.name}'s Archive", description="Default archive created automatically", actor=default_user
name=f"{agent.name}'s Archive",
description="Default archive created automatically",
embedding_config=DEFAULT_EMBEDDING_CONFIG,
actor=default_user,
)
call_count = 0
@@ -206,9 +215,7 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau
with patch.object(server.archive_manager, "create_archive_async", side_effect=track_create):
with patch.object(server.archive_manager, "attach_agent_to_archive_async", side_effect=failing_attach):
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=agent.id, agent_name=agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent, actor=default_user)
assert archive is not None
assert archive.id == concurrent_archive.id # Should return the existing archive
@@ -230,9 +237,7 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau
@pytest.mark.asyncio
async def test_archive_manager_get_agent_from_passage_async(server: SyncServer, default_user, sarah_agent):
"""Test getting the agent ID that owns a passage through its archive."""
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
passage = await server.passage_manager.create_agent_passage_async(
PydanticPassage(
@@ -250,7 +255,7 @@ async def test_archive_manager_get_agent_from_passage_async(server: SyncServer,
assert agent_id == sarah_agent.id
orphan_archive = await server.archive_manager.create_archive_async(
name="orphan_archive", description="Archive with no agents", actor=default_user
name="orphan_archive", description="Archive with no agents", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
orphan_passage = await server.passage_manager.create_agent_passage_async(
@@ -278,7 +283,7 @@ async def test_archive_manager_create_archive_async(server: SyncServer, default_
"""Test creating a new archive with various parameters."""
# test creating with name and description
archive = await server.archive_manager.create_archive_async(
name="test_archive_basic", description="Test archive description", actor=default_user
name="test_archive_basic", description="Test archive description", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
assert archive.name == "test_archive_basic"
@@ -287,7 +292,9 @@ async def test_archive_manager_create_archive_async(server: SyncServer, default_
assert archive.id is not None
# test creating without description
archive2 = await server.archive_manager.create_archive_async(name="test_archive_no_desc", actor=default_user)
archive2 = await server.archive_manager.create_archive_async(
name="test_archive_no_desc", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
assert archive2.name == "test_archive_no_desc"
assert archive2.description is None
@@ -303,7 +310,7 @@ async def test_archive_manager_get_archive_by_id_async(server: SyncServer, defau
"""Test retrieving an archive by its ID."""
# create an archive
archive = await server.archive_manager.create_archive_async(
name="test_get_by_id", description="Archive to test get_by_id", actor=default_user
name="test_get_by_id", description="Archive to test get_by_id", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
# retrieve the archive
@@ -327,7 +334,7 @@ async def test_archive_manager_update_archive_async(server: SyncServer, default_
"""Test updating archive name and description."""
# create an archive
archive = await server.archive_manager.create_archive_async(
name="original_name", description="original description", actor=default_user
name="original_name", description="original description", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
# update name only
@@ -370,7 +377,7 @@ async def test_archive_manager_list_archives_async(server: SyncServer, default_u
archives = []
for i in range(5):
archive = await server.archive_manager.create_archive_async(
name=f"list_test_archive_{i}", description=f"Description {i}", actor=default_user
name=f"list_test_archive_{i}", description=f"Description {i}", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
archives.append(archive)
@@ -413,8 +420,12 @@ async def test_archive_manager_list_archives_async(server: SyncServer, default_u
async def test_archive_manager_attach_agent_to_archive_async(server: SyncServer, default_user, sarah_agent):
"""Test attaching agents to archives with ownership settings."""
# create archives
archive1 = await server.archive_manager.create_archive_async(name="archive_for_attachment_1", actor=default_user)
archive2 = await server.archive_manager.create_archive_async(name="archive_for_attachment_2", actor=default_user)
archive1 = await server.archive_manager.create_archive_async(
name="archive_for_attachment_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
archive2 = await server.archive_manager.create_archive_async(
name="archive_for_attachment_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
# create another agent
agent2 = await server.agent_manager.create_agent_async(
@@ -466,7 +477,10 @@ async def test_archive_manager_detach_agent_from_archive_async(server: SyncServe
"""Test detaching agents from archives."""
# create archive and agents
archive = await server.archive_manager.create_archive_async(
name="archive_for_detachment", description="Test archive for detachment", actor=default_user
name="archive_for_detachment",
description="Test archive for detachment",
embedding_config=DEFAULT_EMBEDDING_CONFIG,
actor=default_user,
)
agent1 = await server.agent_manager.create_agent_async(
@@ -539,7 +553,9 @@ async def test_archive_manager_detach_agent_from_archive_async(server: SyncServe
async def test_archive_manager_attach_detach_idempotency(server: SyncServer, default_user):
"""Test that attach and detach operations are idempotent."""
# create archive and agent
archive = await server.archive_manager.create_archive_async(name="idempotency_test_archive", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="idempotency_test_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
agent = await server.agent_manager.create_agent_async(
agent_create=CreateAgent(
@@ -598,8 +614,12 @@ async def test_archive_manager_attach_detach_idempotency(server: SyncServer, def
async def test_archive_manager_detach_with_multiple_archives(server: SyncServer, default_user):
"""Test detaching an agent from one archive doesn't affect others."""
# create two archives
archive1 = await server.archive_manager.create_archive_async(name="multi_archive_1", actor=default_user)
archive2 = await server.archive_manager.create_archive_async(name="multi_archive_2", actor=default_user)
archive1 = await server.archive_manager.create_archive_async(
name="multi_archive_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
archive2 = await server.archive_manager.create_archive_async(
name="multi_archive_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
# create two agents
agent1 = await server.agent_manager.create_agent_async(
@@ -663,7 +683,9 @@ async def test_archive_manager_detach_with_multiple_archives(server: SyncServer,
async def test_archive_manager_detach_deleted_agent(server: SyncServer, default_user):
"""Test behavior when detaching a deleted agent."""
# create archive
archive = await server.archive_manager.create_archive_async(name="test_deleted_agent_archive", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="test_deleted_agent_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
# create and attach agent
agent = await server.agent_manager.create_agent_async(
@@ -701,7 +723,10 @@ async def test_archive_manager_cascade_delete_on_archive_deletion(server: SyncSe
"""Test that deleting an archive cascades to delete relationships in archives_agents table."""
# create archive
archive = await server.archive_manager.create_archive_async(
name="archive_to_be_deleted", description="This archive will be deleted to test CASCADE", actor=default_user
name="archive_to_be_deleted",
description="This archive will be deleted to test CASCADE",
embedding_config=DEFAULT_EMBEDDING_CONFIG,
actor=default_user,
)
# create multiple agents and attach them to the archive
@@ -777,7 +802,10 @@ async def test_archive_manager_list_agents_with_pagination(server: SyncServer, d
"""Test listing agents for an archive with pagination support."""
# create archive
archive = await server.archive_manager.create_archive_async(
name="pagination_test_archive", description="Archive for testing pagination", actor=default_user
name="pagination_test_archive",
description="Archive for testing pagination",
embedding_config=DEFAULT_EMBEDDING_CONFIG,
actor=default_user,
)
# create multiple agents
@@ -855,7 +883,9 @@ async def test_archive_manager_get_default_archive_for_agent_async(server: SyncS
assert archive is None
# create and attach an archive
created_archive = await server.archive_manager.create_archive_async(name="default_archive", actor=default_user)
created_archive = await server.archive_manager.create_archive_async(
name="default_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
await server.archive_manager.attach_agent_to_archive_async(
agent_id=agent.id, archive_id=created_archive.id, is_owner=True, actor=default_user
@@ -875,7 +905,9 @@ async def test_archive_manager_get_default_archive_for_agent_async(server: SyncS
async def test_archive_manager_get_or_set_vector_db_namespace_async(server: SyncServer, default_user):
"""Test getting or setting vector database namespace for an archive."""
# create an archive
archive = await server.archive_manager.create_archive_async(name="test_vector_namespace", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="test_vector_namespace", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
# get/set namespace for the first time
namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive_id=archive.id)
@@ -897,7 +929,10 @@ async def test_archive_manager_get_agents_with_include_parameter(server: SyncSer
"""Test getting agents for an archive with include parameter to load relationships."""
# create an archive
archive = await server.archive_manager.create_archive_async(
name="test_include_archive", description="Test archive for include parameter", actor=default_user
name="test_include_archive",
description="Test archive for include parameter",
embedding_config=DEFAULT_EMBEDDING_CONFIG,
actor=default_user,
)
# create agent without base tools (to avoid needing tools in test DB)
@@ -961,7 +996,10 @@ async def test_archive_manager_delete_passage_from_archive_async(server: SyncSer
"""Test deleting a passage from an archive."""
# create archive
archive = await server.archive_manager.create_archive_async(
name="test_passage_deletion_archive", description="Archive for testing passage deletion", actor=default_user
name="test_passage_deletion_archive",
description="Archive for testing passage deletion",
embedding_config=DEFAULT_EMBEDDING_CONFIG,
actor=default_user,
)
# create passages
@@ -1015,8 +1053,12 @@ async def test_archive_manager_delete_passage_from_archive_async(server: SyncSer
async def test_archive_manager_delete_passage_from_wrong_archive(server: SyncServer, default_user):
"""Test that deleting a passage from the wrong archive raises an error."""
# create two archives
archive1 = await server.archive_manager.create_archive_async(name="archive_1", actor=default_user)
archive2 = await server.archive_manager.create_archive_async(name="archive_2", actor=default_user)
archive1 = await server.archive_manager.create_archive_async(
name="archive_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
archive2 = await server.archive_manager.create_archive_async(
name="archive_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
# create passage in archive1
passage = await server.passage_manager.create_agent_passage_async(
@@ -1048,7 +1090,9 @@ async def test_archive_manager_delete_passage_from_wrong_archive(server: SyncSer
async def test_archive_manager_delete_nonexistent_passage(server: SyncServer, default_user):
"""Test that deleting a non-existent passage raises an error."""
# create archive
archive = await server.archive_manager.create_archive_async(name="test_nonexistent_passage_archive", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="test_nonexistent_passage_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
# attempt to delete non-existent passage (use valid UUID4 format)
fake_passage_id = f"passage-{uuid.uuid4()}"
@@ -1065,7 +1109,9 @@ async def test_archive_manager_delete_nonexistent_passage(server: SyncServer, de
async def test_archive_manager_delete_passage_from_nonexistent_archive(server: SyncServer, default_user):
"""Test that deleting a passage from a non-existent archive raises an error."""
# create archive and passage
archive = await server.archive_manager.create_archive_async(name="temp_archive", actor=default_user)
archive = await server.archive_manager.create_archive_async(
name="temp_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
)
passage = await server.passage_manager.create_agent_passage_async(
PydanticPassage(

View File

@@ -54,7 +54,7 @@ from letta.orm import Base, Block
from letta.orm.block_history import BlockHistory
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import (
@@ -245,9 +245,7 @@ async def test_agent_list_passages_vector_search(
embed_model = mock_embed_model
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Create passages with known embeddings
passages = []
@@ -429,9 +427,7 @@ async def test_passage_cascade_deletion(
async def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
"""Test creating an agent passage using the new agent-specific method."""
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
passage = await server.passage_manager.create_agent_passage_async(
PydanticPassage(
@@ -494,9 +490,7 @@ async def test_create_agent_passage_validation(server: SyncServer, default_user,
)
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Should fail if source_id is present
with pytest.raises(ValueError, match="Agent passage cannot have source_id"):
@@ -530,9 +524,7 @@ async def test_create_source_passage_validation(server: SyncServer, default_user
)
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Should fail if archive_id is present
with pytest.raises(ValueError, match="Source passage cannot have archive_id"):
@@ -554,9 +546,7 @@ async def test_create_source_passage_validation(server: SyncServer, default_user
async def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sarah_agent):
"""Test retrieving an agent passage using the new agent-specific method."""
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Create an agent passage
passage = await server.passage_manager.create_agent_passage_async(
@@ -608,9 +598,7 @@ async def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sa
"""Test that trying to get the wrong passage type with specific methods fails."""
# Create an agent passage
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
agent_passage = await server.passage_manager.create_agent_passage_async(
PydanticPassage(
@@ -650,9 +638,7 @@ async def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sa
async def test_update_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
"""Test updating an agent passage using the new agent-specific method."""
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Create an agent passage
passage = await server.passage_manager.create_agent_passage_async(
@@ -724,9 +710,7 @@ async def test_update_source_passage_specific(server: SyncServer, default_user,
async def test_delete_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
"""Test deleting an agent passage using the new agent-specific method."""
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Create an agent passage
passage = await server.passage_manager.create_agent_passage_async(
@@ -787,9 +771,7 @@ async def test_delete_source_passage_specific(server: SyncServer, default_user,
async def test_create_many_agent_passages_async(server: SyncServer, default_user, sarah_agent):
"""Test creating multiple agent passages using the new batch method."""
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
passages = [
PydanticPassage(
@@ -846,9 +828,7 @@ async def test_agent_passage_size(server: SyncServer, default_user, sarah_agent)
initial_size = await server.passage_manager.agent_passage_size_async(actor=default_user, agent_id=sarah_agent.id)
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Create some agent passages
for i in range(3):
@@ -873,9 +853,7 @@ async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServe
from letta.schemas.enums import TagMatchMode
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Create passages with different tag combinations
test_passages = [
@@ -969,8 +947,7 @@ async def test_comprehensive_tag_functionality(disable_turbopuffer, server: Sync
# Test 2: Verify unique tags for archive
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id,
agent_name=sarah_agent.name,
agent_state=sarah_agent,
actor=default_user,
)
@@ -1194,8 +1171,7 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age
# Verify unique tags includes all special character tags
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id,
agent_name=sarah_agent.name,
agent_state=sarah_agent,
actor=default_user,
)
@@ -1239,9 +1215,7 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age
async def test_search_agent_archival_memory_async(disable_turbopuffer, server: SyncServer, default_user, sarah_agent):
"""Test the search_agent_archival_memory_async method that powers both the agent tool and API endpoint."""
# Get or create default archive for the agent
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
# Create test passages with various content and tags
test_data = [