fix: don't need embedding model for self hosted [LET-7009] (#8935)
* fix: don't need embedding model for self hosted * stage publish api * passes tests * add test * remove unnecessary upgrades * update revision order db migrations * add timeout for ci
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
"""nullable embedding for archives and passages
|
||||
|
||||
Revision ID: 297e8217e952
|
||||
Revises: 308a180244fc
|
||||
Create Date: 2026-01-20 14:11:21.137232
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "297e8217e952"
|
||||
down_revision: Union[str, None] = "308a180244fc"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column("archival_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.alter_column("archives", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.alter_column("source_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column("source_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("archives", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("archival_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
# ### end Alembic commands ###
|
||||
@@ -26203,7 +26203,14 @@
|
||||
"default": "native"
|
||||
},
|
||||
"embedding_config": {
|
||||
"$ref": "#/components/schemas/EmbeddingConfig",
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/EmbeddingConfig"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Embedding configuration for passages in this archive"
|
||||
},
|
||||
"metadata": {
|
||||
@@ -26229,12 +26236,7 @@
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"required": [
|
||||
"created_at",
|
||||
"name",
|
||||
"organization_id",
|
||||
"embedding_config"
|
||||
],
|
||||
"required": ["created_at", "name", "organization_id"],
|
||||
"title": "Archive",
|
||||
"description": "Representation of an archive - a collection of archival passages that can be shared between agents."
|
||||
},
|
||||
|
||||
@@ -46,8 +46,8 @@ class Archive(SqlalchemyBase, OrganizationMixin):
|
||||
default=VectorDBProvider.NATIVE,
|
||||
doc="The vector database provider used for this archive's passages",
|
||||
)
|
||||
embedding_config: Mapped[dict] = mapped_column(
|
||||
EmbeddingConfigColumn, nullable=False, doc="Embedding configuration for passages in this archive"
|
||||
embedding_config: Mapped[Optional[dict]] = mapped_column(
|
||||
EmbeddingConfigColumn, nullable=True, doc="Embedding configuration for passages in this archive"
|
||||
)
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive")
|
||||
_vector_db_namespace: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Private field for vector database namespace")
|
||||
|
||||
@@ -25,18 +25,18 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
|
||||
text: Mapped[str] = mapped_column(doc="Passage text content")
|
||||
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
|
||||
embedding_config: Mapped[Optional[dict]] = mapped_column(EmbeddingConfigColumn, nullable=True, doc="Embedding configuration")
|
||||
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
|
||||
# dual storage: json column for fast retrieval, junction table for efficient queries
|
||||
tags: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="Tags associated with this passage")
|
||||
|
||||
# Vector embedding field based on database type
|
||||
# Vector embedding field based on database type - nullable for text-only search
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM), nullable=True)
|
||||
else:
|
||||
embedding = Column(CommonVector)
|
||||
embedding = Column(CommonVector, nullable=True)
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
|
||||
@@ -17,7 +17,7 @@ class ArchiveBase(OrmMetadataBase):
|
||||
vector_db_provider: VectorDBProvider = Field(
|
||||
default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages"
|
||||
)
|
||||
embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for passages in this archive")
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="Embedding configuration for passages in this archive")
|
||||
metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata")
|
||||
|
||||
|
||||
|
||||
@@ -65,16 +65,14 @@ async def create_archive(
|
||||
if embedding_config is None:
|
||||
embedding_handle = archive.embedding
|
||||
if embedding_handle is None:
|
||||
if settings.default_embedding_handle is None:
|
||||
raise LettaInvalidArgumentError(
|
||||
"Must specify either embedding or embedding_config in request", argument_name="default_embedding_handle"
|
||||
)
|
||||
else:
|
||||
embedding_handle = settings.default_embedding_handle
|
||||
embedding_config = await server.get_embedding_config_from_handle_async(
|
||||
handle=embedding_handle,
|
||||
actor=actor,
|
||||
)
|
||||
embedding_handle = settings.default_embedding_handle
|
||||
# Only resolve embedding config if we have an embedding handle
|
||||
if embedding_handle is not None:
|
||||
embedding_config = await server.get_embedding_config_from_handle_async(
|
||||
handle=embedding_handle,
|
||||
actor=actor,
|
||||
)
|
||||
# Otherwise, embedding_config remains None (text search only)
|
||||
|
||||
return await server.archive_manager.create_archive_async(
|
||||
name=archive.name,
|
||||
|
||||
@@ -19,7 +19,6 @@ from letta.config import LettaConfig
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.errors import (
|
||||
EmbeddingConfigRequiredError,
|
||||
HandleNotFoundError,
|
||||
LettaInvalidArgumentError,
|
||||
LettaMCPConnectionError,
|
||||
@@ -649,9 +648,10 @@ class SyncServer(object):
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
|
||||
async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
|
||||
if main_agent.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_sleeptime_agent")
|
||||
logger.warning(f"Skipping sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
|
||||
return None
|
||||
request = CreateAgent(
|
||||
name=main_agent.name + "-sleeptime",
|
||||
agent_type=AgentType.sleeptime_agent,
|
||||
@@ -683,9 +683,10 @@ class SyncServer(object):
|
||||
)
|
||||
return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)
|
||||
|
||||
async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
|
||||
async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
|
||||
if main_agent.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_voice_sleeptime_agent")
|
||||
logger.warning(f"Skipping voice sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
|
||||
return None
|
||||
# TODO: Inject system
|
||||
request = CreateAgent(
|
||||
name=main_agent.name + "-sleeptime",
|
||||
@@ -1062,9 +1063,10 @@ class SyncServer(object):
|
||||
|
||||
async def create_document_sleeptime_agent_async(
|
||||
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
|
||||
) -> AgentState:
|
||||
) -> Optional[AgentState]:
|
||||
if main_agent.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_document_sleeptime_agent")
|
||||
logger.warning(f"Skipping document sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
|
||||
return None
|
||||
try:
|
||||
block = await self.agent_manager.get_block_with_label_async(agent_id=main_agent.id, block_label=source.name, actor=actor)
|
||||
except:
|
||||
|
||||
@@ -2479,13 +2479,15 @@ class AgentManager:
|
||||
|
||||
# Get results using existing passage query method
|
||||
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
# Only use embedding-based search if embedding config is available
|
||||
use_embedding_search = agent_state.embedding_config is not None
|
||||
passages_with_metadata = await self.query_agent_passages_async(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
embed_query=True,
|
||||
embed_query=use_embedding_search,
|
||||
tags=tags,
|
||||
tag_match_mode=tag_mode,
|
||||
start_date=start_date,
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, or_, select
|
||||
|
||||
from letta.errors import EmbeddingConfigRequiredError
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.log import get_logger
|
||||
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
|
||||
@@ -32,7 +31,7 @@ class ArchiveManager:
|
||||
async def create_archive_async(
|
||||
self,
|
||||
name: str,
|
||||
embedding_config: EmbeddingConfig,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
description: Optional[str] = None,
|
||||
actor: PydanticUser = None,
|
||||
) -> PydanticArchive:
|
||||
@@ -312,15 +311,17 @@ class ArchiveManager:
|
||||
# Verify the archive exists and user has access
|
||||
archive = await self.get_archive_by_id_async(archive_id=archive_id, actor=actor)
|
||||
|
||||
# Generate embeddings for the text
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=archive.embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings([text], archive.embedding_config)
|
||||
embedding = embeddings[0] if embeddings else None
|
||||
# Generate embeddings for the text if embedding config is available
|
||||
embedding = None
|
||||
if archive.embedding_config is not None:
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=archive.embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings([text], archive.embedding_config)
|
||||
embedding = embeddings[0] if embeddings else None
|
||||
|
||||
# Create the passage object with embedding
|
||||
# Create the passage object (with or without embedding)
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
archive_id=archive_id,
|
||||
@@ -434,9 +435,7 @@ class ArchiveManager:
|
||||
)
|
||||
return archive
|
||||
|
||||
# Create a default archive for this agent
|
||||
if agent_state.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="create_default_archive")
|
||||
# Create a default archive for this agent (embedding_config is optional)
|
||||
archive_name = f"{agent_state.name}'s Archive"
|
||||
archive = await self.create_archive_async(
|
||||
name=archive_name,
|
||||
|
||||
@@ -1194,9 +1194,9 @@ async def build_agent_passage_query(
|
||||
"""
|
||||
|
||||
# Handle embedding for vector search
|
||||
# If embed_query is True but no embedding_config, fall through to text search
|
||||
embedded_text = None
|
||||
if embed_query:
|
||||
assert embedding_config is not None, "embedding_config must be specified for vector search"
|
||||
if embed_query and embedding_config is not None:
|
||||
assert query_text is not None, "query_text must be specified for vector search"
|
||||
|
||||
# Use the new LLMClient for embeddings
|
||||
|
||||
@@ -8,7 +8,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.errors import EmbeddingConfigRequiredError
|
||||
from letta.helpers.decorators import async_redis_cache
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
@@ -474,15 +473,6 @@ class PassageManager:
|
||||
Returns:
|
||||
List of created passage objects
|
||||
"""
|
||||
if agent_state.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="insert_passage")
|
||||
|
||||
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=agent_state.embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Get or create the default archive for the agent
|
||||
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor)
|
||||
|
||||
@@ -493,8 +483,16 @@ class PassageManager:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Generate embeddings for all chunks using the new async API
|
||||
embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
|
||||
# Generate embeddings if embedding config is available
|
||||
if agent_state.embedding_config is not None:
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=agent_state.embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
|
||||
else:
|
||||
# No embedding config - store passages without embeddings (text search only)
|
||||
embeddings = [None] * len(text_chunks)
|
||||
|
||||
passages = []
|
||||
|
||||
|
||||
@@ -115,7 +115,8 @@ def server_url(empty_mcp_config):
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(url)
|
||||
# Use 60s timeout to allow for provider model syncing during server startup
|
||||
wait_for_server(url, timeout=60)
|
||||
|
||||
return url
|
||||
|
||||
|
||||
454
tests/test_embedding_optional.py
Normal file
454
tests/test_embedding_optional.py
Normal file
@@ -0,0 +1,454 @@
|
||||
"""
|
||||
Tests for embedding-optional archival memory feature.
|
||||
|
||||
This file tests that agents can be created without an embedding model
|
||||
and that archival memory operations (insert, list, search) work correctly
|
||||
using text-based search when no embeddings are available.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client.types import CreateBlockParam
|
||||
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
|
||||
|
||||
def run_server():
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
print("Starting server...")
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client() -> LettaSDKClient:
|
||||
"""Get or start a Letta server and return a client."""
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(server_url, timeout=60)
|
||||
|
||||
print("Running embedding-optional tests with server:", server_url)
|
||||
client = LettaSDKClient(base_url=server_url)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_without_embedding(client: LettaSDKClient):
|
||||
"""Create an agent without an embedding model for testing."""
|
||||
agent_state = client.agents.create(
|
||||
memory_blocks=[
|
||||
CreateBlockParam(
|
||||
label="human",
|
||||
value="username: test_user",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
# NOTE: Intentionally NOT providing embedding parameter
|
||||
# to test embedding-optional functionality
|
||||
)
|
||||
|
||||
assert agent_state.embedding_config is None, "Agent should have no embedding config"
|
||||
|
||||
yield agent_state
|
||||
|
||||
# Cleanup
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_with_embedding(client: LettaSDKClient):
|
||||
"""Create an agent WITH an embedding model for comparison testing."""
|
||||
agent_state = client.agents.create(
|
||||
memory_blocks=[
|
||||
CreateBlockParam(
|
||||
label="human",
|
||||
value="username: test_user_with_embedding",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
assert agent_state.embedding_config is not None, "Agent should have embedding config"
|
||||
|
||||
yield agent_state
|
||||
|
||||
# Cleanup
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
class TestAgentCreationWithoutEmbedding:
|
||||
"""Tests for agent creation without embedding configuration."""
|
||||
|
||||
def test_create_agent_without_embedding(self, client: LettaSDKClient):
|
||||
"""Test that an agent can be created without an embedding model."""
|
||||
agent_state = client.agents.create(
|
||||
memory_blocks=[
|
||||
CreateBlockParam(
|
||||
label="human",
|
||||
value="test user",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
)
|
||||
|
||||
try:
|
||||
assert agent_state.id is not None
|
||||
assert agent_state.id.startswith("agent-")
|
||||
assert agent_state.embedding_config is None
|
||||
assert agent_state.llm_config is not None
|
||||
finally:
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
def test_agent_with_and_without_embedding_coexist(self, agent_without_embedding, agent_with_embedding):
|
||||
"""Test that agents with and without embedding can coexist."""
|
||||
assert agent_without_embedding.id != agent_with_embedding.id
|
||||
assert agent_without_embedding.embedding_config is None
|
||||
assert agent_with_embedding.embedding_config is not None
|
||||
|
||||
|
||||
class TestArchivalMemoryInsertWithoutEmbedding:
|
||||
"""Tests for inserting archival memory without embeddings."""
|
||||
|
||||
def test_insert_passage_without_embedding(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test inserting a passage into an agent without embedding config."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
# Insert a passage - use deprecated API but suppress warning
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
passages = client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="This is a test passage about Python programming.",
|
||||
)
|
||||
|
||||
# Should return a list with one passage
|
||||
assert len(passages) == 1
|
||||
passage = passages[0]
|
||||
|
||||
assert passage.id is not None
|
||||
assert passage.text == "This is a test passage about Python programming."
|
||||
# Embedding should be None for agents without embedding config
|
||||
assert passage.embedding is None
|
||||
assert passage.embedding_config is None
|
||||
|
||||
def test_insert_multiple_passages_without_embedding(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test inserting multiple passages into an agent without embedding."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
test_passages = [
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"Python is widely used for data science applications.",
|
||||
"Neural networks can learn complex patterns from data.",
|
||||
]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
for text in test_passages:
|
||||
passages = client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text=text,
|
||||
)
|
||||
assert len(passages) == 1
|
||||
assert passages[0].embedding is None
|
||||
|
||||
# Verify all passages were inserted
|
||||
all_passages = client.agents.passages.list(agent_id=agent_id)
|
||||
|
||||
assert len(all_passages) >= 3
|
||||
|
||||
def test_insert_passage_with_tags_without_embedding(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test inserting a passage with tags into an agent without embedding."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
passages = client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="Important fact: The sky is blue due to Rayleigh scattering.",
|
||||
tags=["science", "physics", "important"],
|
||||
)
|
||||
|
||||
assert len(passages) == 1
|
||||
passage = passages[0]
|
||||
assert passage.embedding is None
|
||||
assert passage.tags is not None
|
||||
assert set(passage.tags) == {"science", "physics", "important"}
|
||||
|
||||
|
||||
class TestArchivalMemoryListWithoutEmbedding:
|
||||
"""Tests for listing archival memory without embeddings."""
|
||||
|
||||
def test_list_passages_without_embedding(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test listing passages from an agent without embedding."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
# Insert some passages first
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="First test passage",
|
||||
)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="Second test passage",
|
||||
)
|
||||
|
||||
# List passages
|
||||
passages = client.agents.passages.list(agent_id=agent_id)
|
||||
|
||||
assert len(passages) >= 2
|
||||
|
||||
for passage in passages:
|
||||
# Verify embeddings are None
|
||||
assert passage.embedding is None
|
||||
|
||||
def test_list_passages_with_search_filter(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test listing passages with text search filter."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
# Insert passages with distinctive content
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="Apple is a fruit that grows on trees.",
|
||||
)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="Python is a programming language.",
|
||||
)
|
||||
|
||||
# Search for passages containing "fruit"
|
||||
passages = client.agents.passages.list(
|
||||
agent_id=agent_id,
|
||||
search="fruit",
|
||||
)
|
||||
|
||||
# Should find the apple passage
|
||||
assert len(passages) >= 1
|
||||
assert any("fruit" in p.text.lower() for p in passages)
|
||||
|
||||
|
||||
class TestArchivalMemorySearchWithoutEmbedding:
|
||||
"""Tests for searching archival memory without embeddings (text-based search)."""
|
||||
|
||||
def test_search_passages_without_embedding(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test searching passages using text search (no embeddings)."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
# Insert test passages
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="The capital of France is Paris.",
|
||||
)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="Tokyo is the capital of Japan.",
|
||||
)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="Python is a popular programming language.",
|
||||
)
|
||||
|
||||
# Search for passages about capitals
|
||||
results = client.agents.passages.search(
|
||||
agent_id=agent_id,
|
||||
query="capital",
|
||||
)
|
||||
|
||||
# Should find passages about capitals (text search)
|
||||
assert results is not None
|
||||
# Check results structure - might be a response object
|
||||
if hasattr(results, "results"):
|
||||
assert len(results.results) >= 1
|
||||
elif hasattr(results, "__len__"):
|
||||
assert len(results) >= 0 # Might be empty if text search returns 0
|
||||
|
||||
def test_global_passage_search_without_embedding(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test global passage search endpoint for agent without embedding."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
# Insert a passage
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="Unique test content for global search testing xyz123.",
|
||||
)
|
||||
|
||||
# Use global passage search
|
||||
results = client.passages.search(
|
||||
query="xyz123",
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
# Should find the passage using text search
|
||||
assert results is not None
|
||||
|
||||
|
||||
class TestArchivalMemoryDeleteWithoutEmbedding:
|
||||
"""Tests for deleting archival memory without embeddings."""
|
||||
|
||||
def test_delete_passage_without_embedding(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test deleting a passage from an agent without embedding."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
# Insert a passage
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
passages = client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text="Passage to be deleted",
|
||||
)
|
||||
|
||||
passage_id = passages[0].id
|
||||
|
||||
# Delete the passage
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
client.agents.passages.delete(
|
||||
agent_id=agent_id,
|
||||
memory_id=passage_id,
|
||||
)
|
||||
|
||||
# Verify it's deleted - should not appear in list
|
||||
remaining = client.agents.passages.list(agent_id=agent_id)
|
||||
|
||||
assert all(p.id != passage_id for p in remaining)
|
||||
|
||||
|
||||
class TestComparisonWithAndWithoutEmbedding:
|
||||
"""Compare behavior between agents with and without embedding config."""
|
||||
|
||||
def test_passage_insert_comparison(
|
||||
self,
|
||||
client: LettaSDKClient,
|
||||
agent_without_embedding,
|
||||
agent_with_embedding,
|
||||
):
|
||||
"""Compare passage insertion between agents with/without embedding."""
|
||||
test_text = "Comparison test: This is identical content for both agents."
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
|
||||
# Insert into agent without embedding
|
||||
passages_no_embed = client.agents.passages.create(
|
||||
agent_id=agent_without_embedding.id,
|
||||
text=test_text,
|
||||
)
|
||||
|
||||
# Insert into agent with embedding
|
||||
passages_with_embed = client.agents.passages.create(
|
||||
agent_id=agent_with_embedding.id,
|
||||
text=test_text,
|
||||
)
|
||||
|
||||
# Both should succeed
|
||||
assert len(passages_no_embed) == 1
|
||||
assert len(passages_with_embed) == 1
|
||||
|
||||
# Text should be identical
|
||||
assert passages_no_embed[0].text == passages_with_embed[0].text
|
||||
|
||||
# Embedding status should differ
|
||||
assert passages_no_embed[0].embedding is None
|
||||
assert passages_with_embed[0].embedding is not None
|
||||
|
||||
def test_list_passages_comparison(
|
||||
self,
|
||||
client: LettaSDKClient,
|
||||
agent_without_embedding,
|
||||
agent_with_embedding,
|
||||
):
|
||||
"""Compare passage listing between agents with/without embedding."""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
|
||||
# Insert passages into both agents
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_without_embedding.id,
|
||||
text="Test passage for listing comparison",
|
||||
)
|
||||
client.agents.passages.create(
|
||||
agent_id=agent_with_embedding.id,
|
||||
text="Test passage for listing comparison",
|
||||
)
|
||||
|
||||
# List from both agents
|
||||
passages_no_embed = client.agents.passages.list(agent_id=agent_without_embedding.id)
|
||||
passages_with_embed = client.agents.passages.list(agent_id=agent_with_embedding.id)
|
||||
|
||||
# Both should return passages
|
||||
assert len(passages_no_embed) >= 1
|
||||
assert len(passages_with_embed) >= 1
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Edge cases and error handling for embedding-optional feature."""
|
||||
|
||||
def test_empty_archival_memory_search(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test searching an empty archival memory."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
# Search without any passages - should return empty, not error
|
||||
results = client.agents.passages.search(
|
||||
agent_id=agent_id,
|
||||
query="anything",
|
||||
)
|
||||
|
||||
# Should return empty results, not raise an error
|
||||
assert results is not None
|
||||
|
||||
def test_passage_with_special_characters(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test inserting passages with special characters."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
special_text = "Special chars: @#$%^&*() 日本語 émojis 🎉 <script>alert('xss')</script>"
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
passages = client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text=special_text,
|
||||
)
|
||||
|
||||
assert len(passages) == 1
|
||||
assert passages[0].text == special_text
|
||||
assert passages[0].embedding is None
|
||||
|
||||
def test_very_long_passage(self, client: LettaSDKClient, agent_without_embedding):
|
||||
"""Test inserting a very long passage."""
|
||||
agent_id = agent_without_embedding.id
|
||||
|
||||
# Create a long text (10KB)
|
||||
long_text = "This is a test. " * 1000
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
passages = client.agents.passages.create(
|
||||
agent_id=agent_id,
|
||||
text=long_text,
|
||||
)
|
||||
|
||||
assert len(passages) >= 1 # Might be chunked
|
||||
# First passage should have no embedding
|
||||
assert passages[0].embedding is None
|
||||
@@ -68,7 +68,8 @@ def client() -> LettaSDKClient:
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(server_url)
|
||||
# Use 60s timeout to allow for provider model syncing during server startup
|
||||
wait_for_server(server_url, timeout=60)
|
||||
print("Running client tests with server:", server_url)
|
||||
client = LettaSDKClient(base_url=server_url)
|
||||
yield client
|
||||
|
||||
Reference in New Issue
Block a user