feat: Support embedding config on the archive [LET-5832] (#5714)
* Add embedding config field to archives * Fix alembic script * Simplify archive manager * Fern autogen * Fix failing tests * Fix alembic
This commit is contained in:
committed by
Caren Thomas
parent
c7c0d7507c
commit
e7e86124f9
@@ -0,0 +1,83 @@
|
||||
"""Add embedding config field to Archives table
|
||||
|
||||
Revision ID: f6cd5a1e519d
|
||||
Revises: c6c43222e2de
|
||||
Create Date: 2025-10-23 16:33:53.661122
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
import letta.orm
|
||||
from alembic import op
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f6cd5a1e519d"
|
||||
down_revision: Union[str, None] = "c6c43222e2de"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# step 1: add column as nullable
|
||||
op.add_column("archives", sa.Column("embedding_config", letta.orm.custom_columns.EmbeddingConfigColumn(), nullable=True))
|
||||
|
||||
# step 2: backfill existing archives with embedding configs in batches
|
||||
connection = op.get_bind()
|
||||
|
||||
# default embedding config for archives without passages
|
||||
default_config = EmbeddingConfig.default_config(model_name="letta")
|
||||
default_embedding_config = default_config.model_dump()
|
||||
|
||||
batch_size = 100
|
||||
processed = 0
|
||||
|
||||
# process in batches until no more archives need backfilling
|
||||
while True:
|
||||
archives = connection.execute(
|
||||
text("SELECT id FROM archives WHERE embedding_config IS NULL LIMIT :batch_size"), {"batch_size": batch_size}
|
||||
).fetchall()
|
||||
|
||||
if not archives:
|
||||
break
|
||||
|
||||
for archive in archives:
|
||||
archive_id = archive[0]
|
||||
|
||||
# check if archive has passages
|
||||
first_passage = connection.execute(
|
||||
text("SELECT embedding_config FROM archival_passages WHERE archive_id = :archive_id AND is_deleted = FALSE LIMIT 1"),
|
||||
{"archive_id": archive_id},
|
||||
).fetchone()
|
||||
|
||||
if first_passage and first_passage[0]:
|
||||
embedding_config = first_passage[0]
|
||||
else:
|
||||
embedding_config = default_embedding_config
|
||||
|
||||
# serialize the embedding config to JSON string for raw SQL
|
||||
config_json = json.dumps(embedding_config)
|
||||
|
||||
connection.execute(
|
||||
text("UPDATE archives SET embedding_config = :config WHERE id = :archive_id"),
|
||||
{"config": config_json, "archive_id": archive_id},
|
||||
)
|
||||
|
||||
processed += len(archives)
|
||||
print(f"Backfilled {processed} archives so far...")
|
||||
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
print(f"Backfill complete. Total archives processed: {processed}")
|
||||
|
||||
# step 3: make column non-nullable
|
||||
op.alter_column("archives", "embedding_config", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("archives", "embedding_config")
|
||||
@@ -18582,6 +18582,10 @@
|
||||
"description": "The vector database provider used for this archive's passages",
|
||||
"default": "native"
|
||||
},
|
||||
"embedding_config": {
|
||||
"$ref": "#/components/schemas/EmbeddingConfig",
|
||||
"description": "Embedding configuration for passages in this archive"
|
||||
},
|
||||
"metadata": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -18605,7 +18609,12 @@
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"required": ["created_at", "name", "organization_id"],
|
||||
"required": [
|
||||
"created_at",
|
||||
"name",
|
||||
"organization_id",
|
||||
"embedding_config"
|
||||
],
|
||||
"title": "Archive",
|
||||
"description": "Representation of an archive - a collection of archival passages that can be shared between agents.\n\nParameters:\n id (str): The unique identifier of the archive.\n name (str): The name of the archive.\n description (str): A description of the archive.\n organization_id (str): The organization this archive belongs to.\n created_at (datetime): The creation date of the archive.\n metadata (dict): Additional metadata for the archive."
|
||||
},
|
||||
@@ -18615,6 +18624,10 @@
|
||||
"type": "string",
|
||||
"title": "Name"
|
||||
},
|
||||
"embedding_config": {
|
||||
"$ref": "#/components/schemas/EmbeddingConfig",
|
||||
"description": "Embedding configuration for the archive"
|
||||
},
|
||||
"description": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -18628,7 +18641,7 @@
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["name"],
|
||||
"required": ["name", "embedding_config"],
|
||||
"title": "ArchiveCreateRequest",
|
||||
"description": "Request model for creating an archive.\n\nIntentionally excludes vector_db_provider. These are derived internally (vector DB provider from env)."
|
||||
},
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
from sqlalchemy import JSON, Enum, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.archive import Archive as PydanticArchive
|
||||
@@ -45,6 +46,9 @@ class Archive(SqlalchemyBase, OrganizationMixin):
|
||||
default=VectorDBProvider.NATIVE,
|
||||
doc="The vector database provider used for this archive's passages",
|
||||
)
|
||||
embedding_config: Mapped[dict] = mapped_column(
|
||||
EmbeddingConfigColumn, nullable=False, doc="Embedding configuration for passages in this archive"
|
||||
)
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive")
|
||||
_vector_db_namespace: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Private field for vector database namespace")
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
|
||||
@@ -16,6 +17,7 @@ class ArchiveBase(OrmMetadataBase):
|
||||
vector_db_provider: VectorDBProvider = Field(
|
||||
default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages"
|
||||
)
|
||||
embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for passages in this archive")
|
||||
metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata")
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta import AgentState
|
||||
from letta.schemas.agent import AgentRelationships
|
||||
from letta.schemas.archive import Archive as PydanticArchive, ArchiveBase
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.validators import AgentId, ArchiveId, PassageId
|
||||
@@ -20,6 +23,7 @@ class ArchiveCreateRequest(BaseModel):
|
||||
"""
|
||||
|
||||
name: str
|
||||
embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for the archive")
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@@ -45,6 +49,7 @@ async def create_archive(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.archive_manager.create_archive_async(
|
||||
name=archive.name,
|
||||
embedding_config=archive.embedding_config,
|
||||
description=archive.description,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.archive import Archive as PydanticArchive
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
@@ -28,6 +29,7 @@ class ArchiveManager:
|
||||
async def create_archive_async(
|
||||
self,
|
||||
name: str,
|
||||
embedding_config: EmbeddingConfig,
|
||||
description: Optional[str] = None,
|
||||
actor: PydanticUser = None,
|
||||
) -> PydanticArchive:
|
||||
@@ -42,6 +44,7 @@ class ArchiveManager:
|
||||
description=description,
|
||||
organization_id=actor.organization_id,
|
||||
vector_db_provider=vector_db_provider,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
await archive.create_async(session, actor=actor)
|
||||
return archive.to_pydantic()
|
||||
@@ -299,11 +302,9 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_or_create_default_archive_for_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
agent_name: Optional[str] = None,
|
||||
agent_state: PydanticAgentState,
|
||||
actor: PydanticUser = None,
|
||||
) -> PydanticArchive:
|
||||
"""Get the agent's default archive, creating one if it doesn't exist."""
|
||||
@@ -315,14 +316,14 @@ class ArchiveManager:
|
||||
agent_manager = AgentManager()
|
||||
|
||||
archive_ids = await agent_manager.get_agent_archive_ids_async(
|
||||
agent_id=agent_id,
|
||||
agent_id=agent_state.id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if archive_ids:
|
||||
# TODO: Remove this check once we support multiple archives per agent
|
||||
if len(archive_ids) > 1:
|
||||
raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported")
|
||||
raise ValueError(f"Agent {agent_state.id} has multiple archives, which is not yet supported")
|
||||
# Get the archive
|
||||
archive = await self.get_archive_by_id_async(
|
||||
archive_id=archive_ids[0],
|
||||
@@ -331,9 +332,10 @@ class ArchiveManager:
|
||||
return archive
|
||||
|
||||
# Create a default archive for this agent
|
||||
archive_name = f"{agent_name or f'Agent {agent_id}'}'s Archive"
|
||||
archive_name = f"{agent_state.name}'s Archive"
|
||||
archive = await self.create_archive_async(
|
||||
name=archive_name,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
description="Default archive created automatically",
|
||||
actor=actor,
|
||||
)
|
||||
@@ -341,7 +343,7 @@ class ArchiveManager:
|
||||
try:
|
||||
# Attach the agent to the archive as owner
|
||||
await self.attach_agent_to_archive_async(
|
||||
agent_id=agent_id,
|
||||
agent_id=agent_state.id,
|
||||
archive_id=archive.id,
|
||||
is_owner=True,
|
||||
actor=actor,
|
||||
@@ -350,12 +352,12 @@ class ArchiveManager:
|
||||
except IntegrityError:
|
||||
# race condition: another concurrent request already created and attached an archive
|
||||
# clean up the orphaned archive we just created
|
||||
logger.info(f"Race condition detected for agent {agent_id}, cleaning up orphaned archive {archive.id}")
|
||||
logger.info(f"Race condition detected for agent {agent_state.id}, cleaning up orphaned archive {archive.id}")
|
||||
await self.delete_archive_async(archive_id=archive.id, actor=actor)
|
||||
|
||||
# fetch the existing archive that was created by the concurrent request
|
||||
archive_ids = await agent_manager.get_agent_archive_ids_async(
|
||||
agent_id=agent_id,
|
||||
agent_id=agent_state.id,
|
||||
actor=actor,
|
||||
)
|
||||
if archive_ids:
|
||||
|
||||
@@ -437,9 +437,7 @@ class PassageManager:
|
||||
)
|
||||
|
||||
# Get or create the default archive for the agent
|
||||
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=agent_state.id, agent_name=agent_state.name, actor=actor
|
||||
)
|
||||
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor)
|
||||
|
||||
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
|
||||
|
||||
|
||||
@@ -189,7 +189,9 @@ def test_should_use_tpuf_with_settings():
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_creation_with_tpuf_enabled(server, default_user, enable_turbopuffer):
|
||||
"""Test that archives are created with correct vector_db_provider when TPUF is enabled"""
|
||||
archive = await server.archive_manager.create_archive_async(name="Test Archive with TPUF", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="Test Archive with TPUF", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
|
||||
)
|
||||
assert archive.vector_db_provider == VectorDBProvider.TPUF
|
||||
# TODO: Add cleanup when delete_archive method is available
|
||||
|
||||
@@ -197,7 +199,9 @@ async def test_archive_creation_with_tpuf_enabled(server, default_user, enable_t
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_creation_with_tpuf_disabled(server, default_user, disable_turbopuffer):
|
||||
"""Test that archives default to NATIVE when TPUF is disabled"""
|
||||
archive = await server.archive_manager.create_archive_async(name="Test Archive without TPUF", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="Test Archive without TPUF", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
|
||||
)
|
||||
assert archive.vector_db_provider == VectorDBProvider.NATIVE
|
||||
# TODO: Add cleanup when delete_archive method is available
|
||||
|
||||
@@ -208,7 +212,9 @@ async def test_dual_write_and_query_with_real_tpuf(server, default_user, sarah_a
|
||||
"""Test that passages are written to both SQL and Turbopuffer with real connection and can be queried"""
|
||||
|
||||
# Create a TPUF-enabled archive
|
||||
archive = await server.archive_manager.create_archive_async(name="Test TPUF Archive for Real Dual Write", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="Test TPUF Archive for Real Dual Write", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
|
||||
)
|
||||
assert archive.vector_db_provider == VectorDBProvider.TPUF
|
||||
|
||||
# Attach the agent to the archive
|
||||
@@ -351,9 +357,7 @@ async def test_native_only_operations(server, default_user, sarah_agent, disable
|
||||
"""Test that operations work correctly when using only native PostgreSQL"""
|
||||
|
||||
# Create archive (should be NATIVE since turbopuffer is disabled)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
assert archive.vector_db_provider == VectorDBProvider.NATIVE
|
||||
|
||||
# Insert passages - should only write to SQL
|
||||
@@ -1833,7 +1837,9 @@ async def test_message_date_filtering_with_real_tpuf(enable_message_embedding, d
|
||||
async def test_archive_namespace_tracking(server, default_user, enable_turbopuffer):
|
||||
"""Test that archive namespaces are properly tracked in database"""
|
||||
# Create an archive
|
||||
archive = await server.archive_manager.create_archive_async(name="Test Archive for Namespace", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="Test Archive for Namespace", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
|
||||
)
|
||||
|
||||
# Get namespace - should be generated and stored
|
||||
namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id)
|
||||
@@ -1854,7 +1860,9 @@ async def test_archive_namespace_tracking(server, default_user, enable_turbopuff
|
||||
async def test_namespace_consistency_with_tpuf_client(server, default_user, enable_turbopuffer):
|
||||
"""Test that the namespace from managers matches what tpuf_client would generate"""
|
||||
# Create archive and agent
|
||||
archive = await server.archive_manager.create_archive_async(name="Test Consistency Archive", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="Test Consistency Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
|
||||
)
|
||||
|
||||
# Get namespace from manager
|
||||
archive_namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id)
|
||||
@@ -1875,14 +1883,18 @@ async def test_environment_namespace_variation(server, default_user):
|
||||
try:
|
||||
settings.environment = None
|
||||
|
||||
archive = await server.archive_manager.create_archive_async(name="No Env Archive", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="No Env Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
|
||||
)
|
||||
namespace_no_env = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id)
|
||||
assert namespace_no_env == f"archive_{archive.id}"
|
||||
|
||||
# Test with environment
|
||||
settings.environment = "TESTING"
|
||||
|
||||
archive2 = await server.archive_manager.create_archive_async(name="With Env Archive", actor=default_user)
|
||||
archive2 = await server.archive_manager.create_archive_async(
|
||||
name="With Env Archive", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
|
||||
)
|
||||
namespace_with_env = await server.archive_manager.get_or_set_vector_db_namespace_async(archive2.id)
|
||||
assert namespace_with_env == f"archive_{archive2.id}_testing"
|
||||
|
||||
|
||||
@@ -395,7 +395,9 @@ async def comprehensive_test_agent_fixture(server: SyncServer, default_user, pri
|
||||
@pytest.fixture
|
||||
async def default_archive(server: SyncServer, default_user):
|
||||
"""Create and return a default archive."""
|
||||
archive = await server.archive_manager.create_archive_async("test", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
"test", embedding_config=EmbeddingConfig.default_config(provider="openai"), actor=default_user
|
||||
)
|
||||
yield archive
|
||||
|
||||
|
||||
@@ -403,9 +405,7 @@ async def default_archive(server: SyncServer, default_user):
|
||||
async def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
"""Create an agent passage."""
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
|
||||
@@ -54,7 +54,7 @@ from letta.orm import Base, Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
||||
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
|
||||
from letta.schemas.agent import AgentRelationships, CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent import AgentRelationships, AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import (
|
||||
@@ -109,7 +109,10 @@ from tests.utils import random_string
|
||||
async def test_archive_manager_delete_archive_async(server: SyncServer, default_user):
|
||||
"""Test the delete_archive_async function."""
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_archive_to_delete", description="This archive will be deleted", actor=default_user
|
||||
name="test_archive_to_delete",
|
||||
description="This archive will be deleted",
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
retrieved_archive = await server.archive_manager.get_archive_by_id_async(archive_id=archive.id, actor=default_user)
|
||||
@@ -125,7 +128,10 @@ async def test_archive_manager_delete_archive_async(server: SyncServer, default_
|
||||
async def test_archive_manager_get_agents_for_archive_async(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test getting all agents that have access to an archive."""
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="shared_archive", description="Archive shared by multiple agents", actor=default_user
|
||||
name="shared_archive",
|
||||
description="Archive shared by multiple agents",
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
@@ -187,7 +193,10 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau
|
||||
|
||||
# First, create an archive that will be attached by a "concurrent" request
|
||||
concurrent_archive = await server.archive_manager.create_archive_async(
|
||||
name=f"{agent.name}'s Archive", description="Default archive created automatically", actor=default_user
|
||||
name=f"{agent.name}'s Archive",
|
||||
description="Default archive created automatically",
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
@@ -206,9 +215,7 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau
|
||||
|
||||
with patch.object(server.archive_manager, "create_archive_async", side_effect=track_create):
|
||||
with patch.object(server.archive_manager, "attach_agent_to_archive_async", side_effect=failing_attach):
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=agent.id, agent_name=agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent, actor=default_user)
|
||||
|
||||
assert archive is not None
|
||||
assert archive.id == concurrent_archive.id # Should return the existing archive
|
||||
@@ -230,9 +237,7 @@ async def test_archive_manager_race_condition_handling(server: SyncServer, defau
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_manager_get_agent_from_passage_async(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test getting the agent ID that owns a passage through its archive."""
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
@@ -250,7 +255,7 @@ async def test_archive_manager_get_agent_from_passage_async(server: SyncServer,
|
||||
assert agent_id == sarah_agent.id
|
||||
|
||||
orphan_archive = await server.archive_manager.create_archive_async(
|
||||
name="orphan_archive", description="Archive with no agents", actor=default_user
|
||||
name="orphan_archive", description="Archive with no agents", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
orphan_passage = await server.passage_manager.create_agent_passage_async(
|
||||
@@ -278,7 +283,7 @@ async def test_archive_manager_create_archive_async(server: SyncServer, default_
|
||||
"""Test creating a new archive with various parameters."""
|
||||
# test creating with name and description
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_archive_basic", description="Test archive description", actor=default_user
|
||||
name="test_archive_basic", description="Test archive description", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
assert archive.name == "test_archive_basic"
|
||||
@@ -287,7 +292,9 @@ async def test_archive_manager_create_archive_async(server: SyncServer, default_
|
||||
assert archive.id is not None
|
||||
|
||||
# test creating without description
|
||||
archive2 = await server.archive_manager.create_archive_async(name="test_archive_no_desc", actor=default_user)
|
||||
archive2 = await server.archive_manager.create_archive_async(
|
||||
name="test_archive_no_desc", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
assert archive2.name == "test_archive_no_desc"
|
||||
assert archive2.description is None
|
||||
@@ -303,7 +310,7 @@ async def test_archive_manager_get_archive_by_id_async(server: SyncServer, defau
|
||||
"""Test retrieving an archive by its ID."""
|
||||
# create an archive
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_get_by_id", description="Archive to test get_by_id", actor=default_user
|
||||
name="test_get_by_id", description="Archive to test get_by_id", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
# retrieve the archive
|
||||
@@ -327,7 +334,7 @@ async def test_archive_manager_update_archive_async(server: SyncServer, default_
|
||||
"""Test updating archive name and description."""
|
||||
# create an archive
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="original_name", description="original description", actor=default_user
|
||||
name="original_name", description="original description", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
# update name only
|
||||
@@ -370,7 +377,7 @@ async def test_archive_manager_list_archives_async(server: SyncServer, default_u
|
||||
archives = []
|
||||
for i in range(5):
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name=f"list_test_archive_{i}", description=f"Description {i}", actor=default_user
|
||||
name=f"list_test_archive_{i}", description=f"Description {i}", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
archives.append(archive)
|
||||
|
||||
@@ -413,8 +420,12 @@ async def test_archive_manager_list_archives_async(server: SyncServer, default_u
|
||||
async def test_archive_manager_attach_agent_to_archive_async(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test attaching agents to archives with ownership settings."""
|
||||
# create archives
|
||||
archive1 = await server.archive_manager.create_archive_async(name="archive_for_attachment_1", actor=default_user)
|
||||
archive2 = await server.archive_manager.create_archive_async(name="archive_for_attachment_2", actor=default_user)
|
||||
archive1 = await server.archive_manager.create_archive_async(
|
||||
name="archive_for_attachment_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
archive2 = await server.archive_manager.create_archive_async(
|
||||
name="archive_for_attachment_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
# create another agent
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
@@ -466,7 +477,10 @@ async def test_archive_manager_detach_agent_from_archive_async(server: SyncServe
|
||||
"""Test detaching agents from archives."""
|
||||
# create archive and agents
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="archive_for_detachment", description="Test archive for detachment", actor=default_user
|
||||
name="archive_for_detachment",
|
||||
description="Test archive for detachment",
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
@@ -539,7 +553,9 @@ async def test_archive_manager_detach_agent_from_archive_async(server: SyncServe
|
||||
async def test_archive_manager_attach_detach_idempotency(server: SyncServer, default_user):
|
||||
"""Test that attach and detach operations are idempotent."""
|
||||
# create archive and agent
|
||||
archive = await server.archive_manager.create_archive_async(name="idempotency_test_archive", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="idempotency_test_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
agent = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
@@ -598,8 +614,12 @@ async def test_archive_manager_attach_detach_idempotency(server: SyncServer, def
|
||||
async def test_archive_manager_detach_with_multiple_archives(server: SyncServer, default_user):
|
||||
"""Test detaching an agent from one archive doesn't affect others."""
|
||||
# create two archives
|
||||
archive1 = await server.archive_manager.create_archive_async(name="multi_archive_1", actor=default_user)
|
||||
archive2 = await server.archive_manager.create_archive_async(name="multi_archive_2", actor=default_user)
|
||||
archive1 = await server.archive_manager.create_archive_async(
|
||||
name="multi_archive_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
archive2 = await server.archive_manager.create_archive_async(
|
||||
name="multi_archive_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
# create two agents
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
@@ -663,7 +683,9 @@ async def test_archive_manager_detach_with_multiple_archives(server: SyncServer,
|
||||
async def test_archive_manager_detach_deleted_agent(server: SyncServer, default_user):
|
||||
"""Test behavior when detaching a deleted agent."""
|
||||
# create archive
|
||||
archive = await server.archive_manager.create_archive_async(name="test_deleted_agent_archive", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_deleted_agent_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
# create and attach agent
|
||||
agent = await server.agent_manager.create_agent_async(
|
||||
@@ -701,7 +723,10 @@ async def test_archive_manager_cascade_delete_on_archive_deletion(server: SyncSe
|
||||
"""Test that deleting an archive cascades to delete relationships in archives_agents table."""
|
||||
# create archive
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="archive_to_be_deleted", description="This archive will be deleted to test CASCADE", actor=default_user
|
||||
name="archive_to_be_deleted",
|
||||
description="This archive will be deleted to test CASCADE",
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# create multiple agents and attach them to the archive
|
||||
@@ -777,7 +802,10 @@ async def test_archive_manager_list_agents_with_pagination(server: SyncServer, d
|
||||
"""Test listing agents for an archive with pagination support."""
|
||||
# create archive
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="pagination_test_archive", description="Archive for testing pagination", actor=default_user
|
||||
name="pagination_test_archive",
|
||||
description="Archive for testing pagination",
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# create multiple agents
|
||||
@@ -855,7 +883,9 @@ async def test_archive_manager_get_default_archive_for_agent_async(server: SyncS
|
||||
assert archive is None
|
||||
|
||||
# create and attach an archive
|
||||
created_archive = await server.archive_manager.create_archive_async(name="default_archive", actor=default_user)
|
||||
created_archive = await server.archive_manager.create_archive_async(
|
||||
name="default_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
await server.archive_manager.attach_agent_to_archive_async(
|
||||
agent_id=agent.id, archive_id=created_archive.id, is_owner=True, actor=default_user
|
||||
@@ -875,7 +905,9 @@ async def test_archive_manager_get_default_archive_for_agent_async(server: SyncS
|
||||
async def test_archive_manager_get_or_set_vector_db_namespace_async(server: SyncServer, default_user):
|
||||
"""Test getting or setting vector database namespace for an archive."""
|
||||
# create an archive
|
||||
archive = await server.archive_manager.create_archive_async(name="test_vector_namespace", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_vector_namespace", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
# get/set namespace for the first time
|
||||
namespace = await server.archive_manager.get_or_set_vector_db_namespace_async(archive_id=archive.id)
|
||||
@@ -897,7 +929,10 @@ async def test_archive_manager_get_agents_with_include_parameter(server: SyncSer
|
||||
"""Test getting agents for an archive with include parameter to load relationships."""
|
||||
# create an archive
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_include_archive", description="Test archive for include parameter", actor=default_user
|
||||
name="test_include_archive",
|
||||
description="Test archive for include parameter",
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# create agent without base tools (to avoid needing tools in test DB)
|
||||
@@ -961,7 +996,10 @@ async def test_archive_manager_delete_passage_from_archive_async(server: SyncSer
|
||||
"""Test deleting a passage from an archive."""
|
||||
# create archive
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_passage_deletion_archive", description="Archive for testing passage deletion", actor=default_user
|
||||
name="test_passage_deletion_archive",
|
||||
description="Archive for testing passage deletion",
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# create passages
|
||||
@@ -1015,8 +1053,12 @@ async def test_archive_manager_delete_passage_from_archive_async(server: SyncSer
|
||||
async def test_archive_manager_delete_passage_from_wrong_archive(server: SyncServer, default_user):
|
||||
"""Test that deleting a passage from the wrong archive raises an error."""
|
||||
# create two archives
|
||||
archive1 = await server.archive_manager.create_archive_async(name="archive_1", actor=default_user)
|
||||
archive2 = await server.archive_manager.create_archive_async(name="archive_2", actor=default_user)
|
||||
archive1 = await server.archive_manager.create_archive_async(
|
||||
name="archive_1", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
archive2 = await server.archive_manager.create_archive_async(
|
||||
name="archive_2", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
# create passage in archive1
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
@@ -1048,7 +1090,9 @@ async def test_archive_manager_delete_passage_from_wrong_archive(server: SyncSer
|
||||
async def test_archive_manager_delete_nonexistent_passage(server: SyncServer, default_user):
|
||||
"""Test that deleting a non-existent passage raises an error."""
|
||||
# create archive
|
||||
archive = await server.archive_manager.create_archive_async(name="test_nonexistent_passage_archive", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_nonexistent_passage_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
# attempt to delete non-existent passage (use valid UUID4 format)
|
||||
fake_passage_id = f"passage-{uuid.uuid4()}"
|
||||
@@ -1065,7 +1109,9 @@ async def test_archive_manager_delete_nonexistent_passage(server: SyncServer, de
|
||||
async def test_archive_manager_delete_passage_from_nonexistent_archive(server: SyncServer, default_user):
|
||||
"""Test that deleting a passage from a non-existent archive raises an error."""
|
||||
# create archive and passage
|
||||
archive = await server.archive_manager.create_archive_async(name="temp_archive", actor=default_user)
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="temp_archive", embedding_config=DEFAULT_EMBEDDING_CONFIG, actor=default_user
|
||||
)
|
||||
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
|
||||
@@ -54,7 +54,7 @@ from letta.orm import Base, Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
||||
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import (
|
||||
@@ -245,9 +245,7 @@ async def test_agent_list_passages_vector_search(
|
||||
embed_model = mock_embed_model
|
||||
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Create passages with known embeddings
|
||||
passages = []
|
||||
@@ -429,9 +427,7 @@ async def test_passage_cascade_deletion(
|
||||
async def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test creating an agent passage using the new agent-specific method."""
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
@@ -494,9 +490,7 @@ async def test_create_agent_passage_validation(server: SyncServer, default_user,
|
||||
)
|
||||
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Should fail if source_id is present
|
||||
with pytest.raises(ValueError, match="Agent passage cannot have source_id"):
|
||||
@@ -530,9 +524,7 @@ async def test_create_source_passage_validation(server: SyncServer, default_user
|
||||
)
|
||||
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Should fail if archive_id is present
|
||||
with pytest.raises(ValueError, match="Source passage cannot have archive_id"):
|
||||
@@ -554,9 +546,7 @@ async def test_create_source_passage_validation(server: SyncServer, default_user
|
||||
async def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test retrieving an agent passage using the new agent-specific method."""
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Create an agent passage
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
@@ -608,9 +598,7 @@ async def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sa
|
||||
"""Test that trying to get the wrong passage type with specific methods fails."""
|
||||
# Create an agent passage
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
agent_passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
@@ -650,9 +638,7 @@ async def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sa
|
||||
async def test_update_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test updating an agent passage using the new agent-specific method."""
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Create an agent passage
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
@@ -724,9 +710,7 @@ async def test_update_source_passage_specific(server: SyncServer, default_user,
|
||||
async def test_delete_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test deleting an agent passage using the new agent-specific method."""
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Create an agent passage
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
@@ -787,9 +771,7 @@ async def test_delete_source_passage_specific(server: SyncServer, default_user,
|
||||
async def test_create_many_agent_passages_async(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test creating multiple agent passages using the new batch method."""
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
passages = [
|
||||
PydanticPassage(
|
||||
@@ -846,9 +828,7 @@ async def test_agent_passage_size(server: SyncServer, default_user, sarah_agent)
|
||||
initial_size = await server.passage_manager.agent_passage_size_async(actor=default_user, agent_id=sarah_agent.id)
|
||||
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Create some agent passages
|
||||
for i in range(3):
|
||||
@@ -873,9 +853,7 @@ async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServe
|
||||
from letta.schemas.enums import TagMatchMode
|
||||
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Create passages with different tag combinations
|
||||
test_passages = [
|
||||
@@ -969,8 +947,7 @@ async def test_comprehensive_tag_functionality(disable_turbopuffer, server: Sync
|
||||
|
||||
# Test 2: Verify unique tags for archive
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id,
|
||||
agent_name=sarah_agent.name,
|
||||
agent_state=sarah_agent,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -1194,8 +1171,7 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age
|
||||
|
||||
# Verify unique tags includes all special character tags
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id,
|
||||
agent_name=sarah_agent.name,
|
||||
agent_state=sarah_agent,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -1239,9 +1215,7 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age
|
||||
async def test_search_agent_archival_memory_async(disable_turbopuffer, server: SyncServer, default_user, sarah_agent):
|
||||
"""Test the search_agent_archival_memory_async method that powers both the agent tool and API endpoint."""
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)
|
||||
|
||||
# Create test passages with various content and tags
|
||||
test_data = [
|
||||
|
||||
Reference in New Issue
Block a user