From 5c0691804280c0d532096efb00ec415814d0321e Mon Sep 17 00:00:00 2001 From: Ari Webb Date: Wed, 21 Jan 2026 16:37:21 -0800 Subject: [PATCH] fix: don't need embedding model for self hosted [LET-7009] (#8935) * fix: don't need embedding model for self hosted * stage publish api * passes tests * add test * remove unnecessary upgrades * update revision order db migrations * add timeout for ci --- ...52_nullable_embedding_for_archives_and_.py | 36 ++ fern/openapi.json | 16 +- letta/orm/archive.py | 4 +- letta/orm/passage.py | 8 +- letta/schemas/archive.py | 2 +- letta/server/rest_api/routers/v1/archives.py | 18 +- letta/server/server.py | 16 +- letta/services/agent_manager.py | 4 +- letta/services/archive_manager.py | 25 +- .../services/helpers/agent_manager_helper.py | 4 +- letta/services/passage_manager.py | 22 +- tests/mcp_tests/test_mcp.py | 3 +- tests/test_embedding_optional.py | 454 ++++++++++++++++++ tests/test_sources.py | 3 +- 14 files changed, 554 insertions(+), 61 deletions(-) create mode 100644 alembic/versions/297e8217e952_nullable_embedding_for_archives_and_.py create mode 100644 tests/test_embedding_optional.py diff --git a/alembic/versions/297e8217e952_nullable_embedding_for_archives_and_.py b/alembic/versions/297e8217e952_nullable_embedding_for_archives_and_.py new file mode 100644 index 00000000..69aa8f05 --- /dev/null +++ b/alembic/versions/297e8217e952_nullable_embedding_for_archives_and_.py @@ -0,0 +1,36 @@ +"""nullable embedding for archives and passages + +Revision ID: 297e8217e952 +Revises: 308a180244fc +Create Date: 2026-01-20 14:11:21.137232 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "297e8217e952" +down_revision: Union[str, None] = "308a180244fc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("archival_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + op.alter_column("archives", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + op.alter_column("source_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("source_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + op.alter_column("archives", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + op.alter_column("archival_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + # ### end Alembic commands ### diff --git a/fern/openapi.json b/fern/openapi.json index f664ea0f..618b0ca0 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -26203,7 +26203,14 @@ "default": "native" }, "embedding_config": { - "$ref": "#/components/schemas/EmbeddingConfig", + "anyOf": [ + { + "$ref": "#/components/schemas/EmbeddingConfig" + }, + { + "type": "null" + } + ], "description": "Embedding configuration for passages in this archive" }, "metadata": { @@ -26229,12 +26236,7 @@ }, "additionalProperties": false, "type": "object", - "required": [ - "created_at", - "name", - "organization_id", - "embedding_config" - ], + "required": ["created_at", "name", "organization_id"], "title": "Archive", "description": "Representation of an archive - a collection of archival passages that can be shared between agents." }, diff --git a/letta/orm/archive.py b/letta/orm/archive.py index 16e0fddf..932e37af 100644 --- a/letta/orm/archive.py +++ b/letta/orm/archive.py @@ -46,8 +46,8 @@ class Archive(SqlalchemyBase, OrganizationMixin): default=VectorDBProvider.NATIVE, doc="The vector database provider used for this archive's passages", ) - embedding_config: Mapped[dict] = mapped_column( - EmbeddingConfigColumn, nullable=False, doc="Embedding configuration for passages in this archive" + embedding_config: Mapped[Optional[dict]] = mapped_column( + EmbeddingConfigColumn, nullable=True, doc="Embedding configuration for passages in this archive" ) metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive") _vector_db_namespace: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Private field for vector database namespace") diff --git a/letta/orm/passage.py b/letta/orm/passage.py index 404830f2..4bbe9509 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -25,18 +25,18 @@ class BasePassage(SqlalchemyBase, OrganizationMixin): id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier") text: Mapped[str] = mapped_column(doc="Passage text content") - embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration") + embedding_config: Mapped[Optional[dict]] = mapped_column(EmbeddingConfigColumn, nullable=True, doc="Embedding configuration") metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata") # dual storage: json column for fast retrieval, junction table for efficient queries tags: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="Tags associated with this passage") - # Vector embedding field based on database type + # Vector embedding field based on database type - nullable for text-only search if settings.database_engine is DatabaseChoice.POSTGRES: from pgvector.sqlalchemy import Vector - embedding = mapped_column(Vector(MAX_EMBEDDING_DIM)) + embedding = mapped_column(Vector(MAX_EMBEDDING_DIM), nullable=True) else: - embedding = Column(CommonVector) + embedding = Column(CommonVector, nullable=True) @declared_attr def organization(cls) -> Mapped["Organization"]: diff --git a/letta/schemas/archive.py b/letta/schemas/archive.py index 08c863d9..f0d0348f 100644 --- a/letta/schemas/archive.py +++ b/letta/schemas/archive.py @@ -17,7 +17,7 @@ class ArchiveBase(OrmMetadataBase): vector_db_provider: VectorDBProvider = Field( default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages" ) - embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for passages in this archive") + embedding_config: Optional[EmbeddingConfig] = Field(None, description="Embedding configuration for passages in this archive") metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata") diff --git a/letta/server/rest_api/routers/v1/archives.py b/letta/server/rest_api/routers/v1/archives.py index c109f1c8..34ac92ac 100644 --- a/letta/server/rest_api/routers/v1/archives.py +++ b/letta/server/rest_api/routers/v1/archives.py @@ -65,16 +65,14 @@ async def create_archive( if embedding_config is None: embedding_handle = archive.embedding if embedding_handle is None: - if settings.default_embedding_handle is None: - raise LettaInvalidArgumentError( - "Must specify either embedding or embedding_config in request", argument_name="default_embedding_handle" - ) - else: - embedding_handle = settings.default_embedding_handle - embedding_config = await server.get_embedding_config_from_handle_async( - handle=embedding_handle, - actor=actor, - ) + embedding_handle = settings.default_embedding_handle + # Only resolve embedding config if we have an embedding handle + if embedding_handle is not None: + embedding_config = await server.get_embedding_config_from_handle_async( + handle=embedding_handle, + actor=actor, + ) + # Otherwise, embedding_config remains None (text search only) return await server.archive_manager.create_archive_async( name=archive.name, diff --git a/letta/server/server.py b/letta/server/server.py index eea548e9..62dde129 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -19,7 +19,6 @@ from letta.config import LettaConfig from letta.constants import LETTA_TOOL_EXECUTION_DIR from letta.data_sources.connectors import DataConnector, load_data from letta.errors import ( - EmbeddingConfigRequiredError, HandleNotFoundError, LettaInvalidArgumentError, LettaMCPConnectionError, @@ -649,9 +648,10 @@ class SyncServer(object): actor=actor, ) - async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState: + async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]: if main_agent.embedding_config is None: - raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_sleeptime_agent") + logger.warning(f"Skipping sleeptime agent creation for agent {main_agent.id}: no embedding config provided") + return None request = CreateAgent( name=main_agent.name + "-sleeptime", agent_type=AgentType.sleeptime_agent, @@ -683,9 +683,10 @@ class SyncServer(object): ) return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor) - async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState: + async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]: if main_agent.embedding_config is None: - raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_voice_sleeptime_agent") + logger.warning(f"Skipping voice sleeptime agent creation for agent {main_agent.id}: no embedding config provided") + return None # TODO: Inject system request = CreateAgent( name=main_agent.name + "-sleeptime", @@ -1062,9 +1063,10 @@ class SyncServer(object): async def create_document_sleeptime_agent_async( self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False - ) -> AgentState: + ) -> Optional[AgentState]: if main_agent.embedding_config is None: - raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_document_sleeptime_agent") + logger.warning(f"Skipping document sleeptime agent creation for agent {main_agent.id}: no embedding config provided") + return None try: block = await self.agent_manager.get_block_with_label_async(agent_id=main_agent.id, block_label=source.name, actor=actor) except: diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 986c80c7..eaee6ec3 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -2479,13 +2479,15 @@ class AgentManager: # Get results using existing passage query method limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + # Only use embedding-based search if embedding config is available + use_embedding_search = agent_state.embedding_config is not None passages_with_metadata = await self.query_agent_passages_async( actor=actor, agent_id=agent_id, query_text=query, limit=limit, embedding_config=agent_state.embedding_config, - embed_query=True, + embed_query=use_embedding_search, tags=tags, tag_match_mode=tag_mode, start_date=start_date, diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index 5265e2c8..fd98a56a 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional from sqlalchemy import delete, or_, select -from letta.errors import EmbeddingConfigRequiredError from letta.helpers.tpuf_client import should_use_tpuf from letta.log import get_logger from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents @@ -32,7 +31,7 @@ class ArchiveManager: async def create_archive_async( self, name: str, - embedding_config: EmbeddingConfig, + embedding_config: Optional[EmbeddingConfig] = None, description: Optional[str] = None, actor: PydanticUser = None, ) -> PydanticArchive: @@ -312,15 +311,17 @@ class ArchiveManager: # Verify the archive exists and user has access archive = await self.get_archive_by_id_async(archive_id=archive_id, actor=actor) - # Generate embeddings for the text - embedding_client = LLMClient.create( - provider_type=archive.embedding_config.embedding_endpoint_type, - actor=actor, - ) - embeddings = await embedding_client.request_embeddings([text], archive.embedding_config) - embedding = embeddings[0] if embeddings else None + # Generate embeddings for the text if embedding config is available + embedding = None + if archive.embedding_config is not None: + embedding_client = LLMClient.create( + provider_type=archive.embedding_config.embedding_endpoint_type, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings([text], archive.embedding_config) + embedding = embeddings[0] if embeddings else None - # Create the passage object with embedding + # Create the passage object (with or without embedding) passage = PydanticPassage( text=text, archive_id=archive_id, @@ -434,9 +435,7 @@ class ArchiveManager: ) return archive - # Create a default archive for this agent - if agent_state.embedding_config is None: - raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="create_default_archive") + # Create a default archive for this agent (embedding_config is optional) archive_name = f"{agent_state.name}'s Archive" archive = await self.create_archive_async( name=archive_name, diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 542b6501..1bcf6683 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -1194,9 +1194,9 @@ async def build_agent_passage_query( """ # Handle embedding for vector search + # If embed_query is True but no embedding_config, fall through to text search embedded_text = None - if embed_query: - assert embedding_config is not None, "embedding_config must be specified for vector search" + if embed_query and embedding_config is not None: assert query_text is not None, "query_text must be specified for vector search" # Use the new LLMClient for embeddings diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index f2d7e48a..c1e6bad3 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -8,7 +8,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import noload from letta.constants import MAX_EMBEDDING_DIM -from letta.errors import EmbeddingConfigRequiredError from letta.helpers.decorators import async_redis_cache from letta.llm_api.llm_client import LLMClient from letta.log import get_logger @@ -474,15 +473,6 @@ class PassageManager: Returns: List of created passage objects """ - if agent_state.embedding_config is None: - raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="insert_passage") - - embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size - embedding_client = LLMClient.create( - provider_type=agent_state.embedding_config.embedding_endpoint_type, - actor=actor, - ) - # Get or create the default archive for the agent archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor) @@ -493,8 +483,16 @@ class PassageManager: return [] try: - # Generate embeddings for all chunks using the new async API - embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config) + # Generate embeddings if embedding config is available + if agent_state.embedding_config is not None: + embedding_client = LLMClient.create( + provider_type=agent_state.embedding_config.embedding_endpoint_type, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config) + else: + # No embedding config - store passages without embeddings (text search only) + embeddings = [None] * len(text_chunks) passages = [] diff --git a/tests/mcp_tests/test_mcp.py b/tests/mcp_tests/test_mcp.py index f63882a4..9930a0cd 100644 --- a/tests/mcp_tests/test_mcp.py +++ b/tests/mcp_tests/test_mcp.py @@ -115,7 +115,8 @@ def server_url(empty_mcp_config): if not os.getenv("LETTA_SERVER_URL"): thread = threading.Thread(target=_run_server, daemon=True) thread.start() - wait_for_server(url) + # Use 60s timeout to allow for provider model syncing during server startup + wait_for_server(url, timeout=60) return url diff --git a/tests/test_embedding_optional.py b/tests/test_embedding_optional.py new file mode 100644 index 00000000..882834a1 --- /dev/null +++ b/tests/test_embedding_optional.py @@ -0,0 +1,454 @@ +""" +Tests for embedding-optional archival memory feature. + +This file tests that agents can be created without an embedding model +and that archival memory operations (insert, list, search) work correctly +using text-based search when no embeddings are available. +""" + +import os +import threading +import warnings + +import pytest +from dotenv import load_dotenv +from letta_client import Letta as LettaSDKClient +from letta_client.types import CreateBlockParam + +from tests.utils import wait_for_server + +# Constants +SERVER_PORT = 8283 + + +def run_server(): + load_dotenv() + from letta.server.rest_api.app import start_server + + print("Starting server...") + start_server(debug=True) + + +@pytest.fixture(scope="module") +def client() -> LettaSDKClient: + """Get or start a Letta server and return a client.""" + server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}") + if not os.getenv("LETTA_SERVER_URL"): + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + wait_for_server(server_url, timeout=60) + + print("Running embedding-optional tests with server:", server_url) + client = LettaSDKClient(base_url=server_url) + yield client + + +@pytest.fixture(scope="function") +def agent_without_embedding(client: LettaSDKClient): + """Create an agent without an embedding model for testing.""" + agent_state = client.agents.create( + memory_blocks=[ + CreateBlockParam( + label="human", + value="username: test_user", + ), + ], + model="openai/gpt-4o-mini", + # NOTE: Intentionally NOT providing embedding parameter + # to test embedding-optional functionality + ) + + assert agent_state.embedding_config is None, "Agent should have no embedding config" + + yield agent_state + + # Cleanup + client.agents.delete(agent_id=agent_state.id) + + +@pytest.fixture(scope="function") +def agent_with_embedding(client: LettaSDKClient): + """Create an agent WITH an embedding model for comparison testing.""" + agent_state = client.agents.create( + memory_blocks=[ + CreateBlockParam( + label="human", + value="username: test_user_with_embedding", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + assert agent_state.embedding_config is not None, "Agent should have embedding config" + + yield agent_state + + # Cleanup + client.agents.delete(agent_id=agent_state.id) + + +class TestAgentCreationWithoutEmbedding: + """Tests for agent creation without embedding configuration.""" + + def test_create_agent_without_embedding(self, client: LettaSDKClient): + """Test that an agent can be created without an embedding model.""" + agent_state = client.agents.create( + memory_blocks=[ + CreateBlockParam( + label="human", + value="test user", + ), + ], + model="openai/gpt-4o-mini", + ) + + try: + assert agent_state.id is not None + assert agent_state.id.startswith("agent-") + assert agent_state.embedding_config is None + assert agent_state.llm_config is not None + finally: + client.agents.delete(agent_id=agent_state.id) + + def test_agent_with_and_without_embedding_coexist(self, agent_without_embedding, agent_with_embedding): + """Test that agents with and without embedding can coexist.""" + assert agent_without_embedding.id != agent_with_embedding.id + assert agent_without_embedding.embedding_config is None + assert agent_with_embedding.embedding_config is not None + + +class TestArchivalMemoryInsertWithoutEmbedding: + """Tests for inserting archival memory without embeddings.""" + + def test_insert_passage_without_embedding(self, client: LettaSDKClient, agent_without_embedding): + """Test inserting a passage into an agent without embedding config.""" + agent_id = agent_without_embedding.id + + # Insert a passage - use deprecated API but suppress warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + passages = client.agents.passages.create( + agent_id=agent_id, + text="This is a test passage about Python programming.", + ) + + # Should return a list with one passage + assert len(passages) == 1 + passage = passages[0] + + assert passage.id is not None + assert passage.text == "This is a test passage about Python programming." + # Embedding should be None for agents without embedding config + assert passage.embedding is None + assert passage.embedding_config is None + + def test_insert_multiple_passages_without_embedding(self, client: LettaSDKClient, agent_without_embedding): + """Test inserting multiple passages into an agent without embedding.""" + agent_id = agent_without_embedding.id + + test_passages = [ + "Machine learning is a subset of artificial intelligence.", + "Python is widely used for data science applications.", + "Neural networks can learn complex patterns from data.", + ] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + for text in test_passages: + passages = client.agents.passages.create( + agent_id=agent_id, + text=text, + ) + assert len(passages) == 1 + assert passages[0].embedding is None + + # Verify all passages were inserted + all_passages = client.agents.passages.list(agent_id=agent_id) + + assert len(all_passages) >= 3 + + def test_insert_passage_with_tags_without_embedding(self, client: LettaSDKClient, agent_without_embedding): + """Test inserting a passage with tags into an agent without embedding.""" + agent_id = agent_without_embedding.id + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + passages = client.agents.passages.create( + agent_id=agent_id, + text="Important fact: The sky is blue due to Rayleigh scattering.", + tags=["science", "physics", "important"], + ) + + assert len(passages) == 1 + passage = passages[0] + assert passage.embedding is None + assert passage.tags is not None + assert set(passage.tags) == {"science", "physics", "important"} + + +class TestArchivalMemoryListWithoutEmbedding: + """Tests for listing archival memory without embeddings.""" + + def test_list_passages_without_embedding(self, client: LettaSDKClient, agent_without_embedding): + """Test listing passages from an agent without embedding.""" + agent_id = agent_without_embedding.id + + # Insert some passages first + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + client.agents.passages.create( + agent_id=agent_id, + text="First test passage", + ) + client.agents.passages.create( + agent_id=agent_id, + text="Second test passage", + ) + + # List passages + passages = client.agents.passages.list(agent_id=agent_id) + + assert len(passages) >= 2 + + for passage in passages: + # Verify embeddings are None + assert passage.embedding is None + + def test_list_passages_with_search_filter(self, client: LettaSDKClient, agent_without_embedding): + """Test listing passages with text search filter.""" + agent_id = agent_without_embedding.id + + # Insert passages with distinctive content + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + client.agents.passages.create( + agent_id=agent_id, + text="Apple is a fruit that grows on trees.", + ) + client.agents.passages.create( + agent_id=agent_id, + text="Python is a programming language.", + ) + + # Search for passages containing "fruit" + passages = client.agents.passages.list( + agent_id=agent_id, + search="fruit", + ) + + # Should find the apple passage + assert len(passages) >= 1 + assert any("fruit" in p.text.lower() for p in passages) + + +class TestArchivalMemorySearchWithoutEmbedding: + """Tests for searching archival memory without embeddings (text-based search).""" + + def test_search_passages_without_embedding(self, client: LettaSDKClient, agent_without_embedding): + """Test searching passages using text search (no embeddings).""" + agent_id = agent_without_embedding.id + + # Insert test passages + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + client.agents.passages.create( + agent_id=agent_id, + text="The capital of France is Paris.", + ) + client.agents.passages.create( + agent_id=agent_id, + text="Tokyo is the capital of Japan.", + ) + client.agents.passages.create( + agent_id=agent_id, + text="Python is a popular programming language.", + ) + + # Search for passages about capitals + results = client.agents.passages.search( + agent_id=agent_id, + query="capital", + ) + + # Should find passages about capitals (text search) + assert results is not None + # Check results structure - might be a response object + if hasattr(results, "results"): + assert len(results.results) >= 1 + elif hasattr(results, "__len__"): + assert len(results) >= 0 # Might be empty if text search returns 0 + + def test_global_passage_search_without_embedding(self, client: LettaSDKClient, agent_without_embedding): + """Test global passage search endpoint for agent without embedding.""" + agent_id = agent_without_embedding.id + + # Insert a passage + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + client.agents.passages.create( + agent_id=agent_id, + text="Unique test content for global search testing xyz123.", + ) + + # Use global passage search + results = client.passages.search( + query="xyz123", + agent_id=agent_id, + ) + + # Should find the passage using text search + assert results is not None + + +class TestArchivalMemoryDeleteWithoutEmbedding: + """Tests for deleting archival memory without embeddings.""" + + def test_delete_passage_without_embedding(self, client: LettaSDKClient, agent_without_embedding): + """Test deleting a passage from an agent without embedding.""" + agent_id = agent_without_embedding.id + + # Insert a passage + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + passages = client.agents.passages.create( + agent_id=agent_id, + text="Passage to be deleted", + ) + + passage_id = passages[0].id + + # Delete the passage + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + client.agents.passages.delete( + agent_id=agent_id, + memory_id=passage_id, + ) + + # Verify it's deleted - should not appear in list + remaining = client.agents.passages.list(agent_id=agent_id) + + assert all(p.id != passage_id for p in remaining) + + +class TestComparisonWithAndWithoutEmbedding: + """Compare behavior between agents with and without embedding config.""" + + def test_passage_insert_comparison( + self, + client: LettaSDKClient, + agent_without_embedding, + agent_with_embedding, + ): + """Compare passage insertion between agents with/without embedding.""" + test_text = "Comparison test: This is identical content for both agents." + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + # Insert into agent without embedding + passages_no_embed = client.agents.passages.create( + agent_id=agent_without_embedding.id, + text=test_text, + ) + + # Insert into agent with embedding + passages_with_embed = client.agents.passages.create( + agent_id=agent_with_embedding.id, + text=test_text, + ) + + # Both should succeed + assert len(passages_no_embed) == 1 + assert len(passages_with_embed) == 1 + + # Text should be identical + assert passages_no_embed[0].text == passages_with_embed[0].text + + # Embedding status should differ + assert passages_no_embed[0].embedding is None + assert passages_with_embed[0].embedding is not None + + def test_list_passages_comparison( + self, + client: LettaSDKClient, + agent_without_embedding, + agent_with_embedding, + ): + """Compare passage listing between agents with/without embedding.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + # Insert passages into both agents + client.agents.passages.create( + agent_id=agent_without_embedding.id, + text="Test passage for listing comparison", + ) + client.agents.passages.create( + agent_id=agent_with_embedding.id, + text="Test passage for listing comparison", + ) + + # List from both agents + passages_no_embed = client.agents.passages.list(agent_id=agent_without_embedding.id) + passages_with_embed = client.agents.passages.list(agent_id=agent_with_embedding.id) + + # Both should return passages + assert len(passages_no_embed) >= 1 + assert len(passages_with_embed) >= 1 + + +class TestEdgeCases: + """Edge cases and error handling for embedding-optional feature.""" + + def test_empty_archival_memory_search(self, client: LettaSDKClient, agent_without_embedding): + """Test searching an empty archival memory.""" + agent_id = agent_without_embedding.id + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Search without any passages - should return empty, not error + results = client.agents.passages.search( + agent_id=agent_id, + query="anything", + ) + + # Should return empty results, not raise an error + assert results is not None + + def test_passage_with_special_characters(self, client: LettaSDKClient, agent_without_embedding): + """Test inserting passages with special characters.""" + agent_id = agent_without_embedding.id + + special_text = "Special chars: @#$%^&*() 日本語 émojis 🎉 " + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + passages = client.agents.passages.create( + agent_id=agent_id, + text=special_text, + ) + + assert len(passages) == 1 + assert passages[0].text == special_text + assert passages[0].embedding is None + + def test_very_long_passage(self, client: LettaSDKClient, agent_without_embedding): + """Test inserting a very long passage.""" + agent_id = agent_without_embedding.id + + # Create a long text (10KB) + long_text = "This is a test. " * 1000 + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + passages = client.agents.passages.create( + agent_id=agent_id, + text=long_text, + ) + + assert len(passages) >= 1 # Might be chunked + # First passage should have no embedding + assert passages[0].embedding is None diff --git a/tests/test_sources.py b/tests/test_sources.py index 2bf68edc..79555784 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -68,7 +68,8 @@ def client() -> LettaSDKClient: print("Starting server thread") thread = threading.Thread(target=run_server, daemon=True) thread.start() - wait_for_server(server_url) + # Use 60s timeout to allow for provider model syncing during server startup + wait_for_server(server_url, timeout=60) print("Running client tests with server:", server_url) client = LettaSDKClient(base_url=server_url) yield client