feat: Improve performance on files related data models (#3285)

This commit is contained in:
Matthew Zhou
2025-07-10 18:00:35 -07:00
committed by GitHub
parent 208d6fefa9
commit c94b227a32
14 changed files with 483 additions and 139 deletions

View File

@@ -0,0 +1,52 @@
"""Write source_id directly to files agents
Revision ID: 495f3f474131
Revises: 47d2277e530d
Create Date: 2025-07-10 17:14:45.154738
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "495f3f474131"
down_revision: Union[str, None] = "47d2277e530d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add a denormalized ``source_id`` column to ``files_agents``.

    The column is backfilled from ``files.source_id`` so per-row lookups of a
    file's source no longer require a join through ``files``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Step 1: Add the column as nullable first so existing rows don't violate NOT NULL
    op.add_column("files_agents", sa.Column("source_id", sa.String(), nullable=True))

    # Step 2: Backfill source_id from files table (UPDATE ... FROM join syntax;
    # supported by PostgreSQL — assumes a PG backend, TODO confirm for other dialects)
    connection = op.get_bind()
    connection.execute(
        sa.text(
            """
            UPDATE files_agents
            SET source_id = files.source_id
            FROM files
            WHERE files_agents.file_id = files.id
            """
        )
    )

    # Step 3: Make the column NOT NULL now that it's populated.
    # existing_type is required by some backends (e.g. MySQL) to emit ALTER correctly.
    op.alter_column("files_agents", "source_id", existing_type=sa.String(), nullable=False)

    # Step 4: Add the foreign key constraint.
    # FIX: name the constraint explicitly instead of passing None — an unnamed
    # constraint cannot be dropped by name in downgrade(), and
    # op.drop_constraint(None, ...) raises "Constraint must have a name".
    op.create_foreign_key(
        "fk_files_agents_source_id",
        "files_agents",
        "sources",
        ["source_id"],
        ["id"],
        ondelete="CASCADE",
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Remove the denormalized ``source_id`` column from ``files_agents``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop the explicitly named FK first, then the column itself.
    op.drop_constraint("fk_files_agents_source_id", "files_agents", type_="foreignkey")
    op.drop_column("files_agents", "source_id")
    # ### end Alembic commands ###

View File

@@ -1,5 +1,5 @@
import uuid
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from sqlalchemy import ForeignKey, Index, Integer, String, Text, UniqueConstraint, desc
from sqlalchemy.ext.asyncio import AsyncAttrs
@@ -11,10 +11,7 @@ from letta.schemas.enums import FileProcessingStatus
from letta.schemas.file import FileMetadata as PydanticFileMetadata
if TYPE_CHECKING:
from letta.orm.files_agents import FileAgent
from letta.orm.organization import Organization
from letta.orm.passage import SourcePassage
from letta.orm.source import Source
pass
# TODO: Note that this is NOT organization scoped, this is potentially dangerous if we misuse this
@@ -64,18 +61,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
chunks_embedded: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Number of chunks that have been embedded.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan"
)
file_agents: Mapped[List["FileAgent"]] = relationship(
"FileAgent",
back_populates="file",
lazy="selectin",
cascade="all, delete-orphan",
passive_deletes=True, # ← add this
)
content: Mapped[Optional["FileContent"]] = relationship(
"FileContent",
uselist=False,

View File

@@ -12,7 +12,7 @@ from letta.schemas.block import Block as PydanticBlock
from letta.schemas.file import FileAgent as PydanticFileAgent
if TYPE_CHECKING:
from letta.orm.file import FileMetadata
pass
class FileAgent(SqlalchemyBase, OrganizationMixin):
@@ -55,6 +55,12 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
nullable=False,
doc="ID of the agent",
)
source_id: Mapped[str] = mapped_column(
String,
ForeignKey("sources.id", ondelete="CASCADE"),
nullable=False,
doc="ID of the source (denormalized from files.source_id)",
)
file_name: Mapped[str] = mapped_column(
String,
@@ -78,13 +84,6 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
back_populates="file_agents",
lazy="selectin",
)
file: Mapped["FileMetadata"] = relationship(
"FileMetadata",
foreign_keys=[file_id],
lazy="selectin",
back_populates="file_agents",
passive_deletes=True, # ← add this
)
# TODO: This is temporary as we figure out if we want FileBlock as a first class citizen
def to_pydantic_block(self) -> PydanticBlock:
@@ -99,8 +98,8 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
return PydanticBlock(
organization_id=self.organization_id,
value=visible_content,
label=self.file.file_name,
label=self.file_name, # use denormalized file_name instead of self.file.file_name
read_only=True,
metadata={"source_id": self.file.source_id},
metadata={"source_id": self.source_id}, # use denormalized source_id
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
)

View File

@@ -9,7 +9,6 @@ if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.agent_passage import AgentPassage
from letta.orm.block import Block
from letta.orm.file import FileMetadata
from letta.orm.group import Group
from letta.orm.identity import Identity
from letta.orm.llm_batch_item import LLMBatchItem
@@ -18,7 +17,6 @@ if TYPE_CHECKING:
from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig
from letta.orm.sandbox_environment_variable import SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.source_passage import SourcePassage
from letta.orm.tool import Tool
from letta.orm.user import User
@@ -38,8 +36,6 @@ class Organization(SqlalchemyBase):
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
# mcp_servers: Mapped[List["MCPServer"]] = relationship("MCPServer", back_populates="organization", cascade="all, delete-orphan")
blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan")
sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
sandbox_configs: Mapped[List["SandboxConfig"]] = relationship(
"SandboxConfig", back_populates="organization", cascade="all, delete-orphan"
)

View File

@@ -49,11 +49,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
file_name: Mapped[str] = mapped_column(doc="The name of the file that this passage was derived from")
@declared_attr
def file(cls) -> Mapped["FileMetadata"]:
"""Relationship to file"""
return relationship("FileMetadata", back_populates="source_passages", lazy="selectin")
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="source_passages", lazy="selectin")
@@ -74,11 +69,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
{"extend_existing": True},
)
@declared_attr
def source(cls) -> Mapped["Source"]:
"""Relationship to source"""
return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True)
class AgentPassage(BasePassage, AgentMixin):
"""Passages created by agents as archival memories"""

View File

@@ -1,9 +1,8 @@
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from sqlalchemy import JSON, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm import FileMetadata
from letta.orm.custom_columns import EmbeddingConfigColumn
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
@@ -11,10 +10,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.source import Source as PydanticSource
if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.file import FileMetadata
from letta.orm.organization import Organization
from letta.orm.passage import SourcePassage
pass
class Source(SqlalchemyBase, OrganizationMixin):
@@ -34,16 +30,3 @@ class Source(SqlalchemyBase, OrganizationMixin):
instructions: Mapped[str] = mapped_column(nullable=True, doc="instructions for how to use the source")
embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.")
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship(
"Agent",
secondary="sources_agents",
back_populates="sources",
lazy="selectin",
cascade="save-update", # Only propagate save and update operations
passive_deletes=True, # Let the database handle deletions
)

View File

@@ -85,6 +85,7 @@ class FileAgent(FileAgentBase):
)
agent_id: str = Field(..., description="Unique identifier of the agent.")
file_id: str = Field(..., description="Unique identifier of the file.")
source_id: str = Field(..., description="Unique identifier of the source (denormalized from files.source_id).")
file_name: str = Field(..., description="Name of the file.")
is_open: bool = Field(True, description="True if the agent currently has the file open.")
visible_content: Optional[str] = Field(

View File

@@ -107,6 +107,31 @@ class AgentManager:
self.identity_manager = IdentityManager()
self.file_agent_manager = FileAgentManager()
async def _validate_agent_exists_async(self, session, agent_id: str, actor: PydanticUser) -> None:
    """
    Validate that an agent exists and user has access to it using raw SQL for efficiency.

    Deliberately avoids loading the ORM Agent (and its eagerly-loaded
    relationships) when callers only need an existence/authorization check.

    Args:
        session: Database session (async; must support ``await session.execute``)
        agent_id: ID of the agent to validate
        actor: User performing the action; access is scoped to ``actor.organization_id``

    Raises:
        NoResultFound: If agent doesn't exist or user doesn't have access
    """
    # Existence + org ownership + soft-delete check in a single indexed lookup.
    agent_check_query = sa.text(
        """
        SELECT 1 FROM agents
        WHERE id = :agent_id
        AND organization_id = :org_id
        AND is_deleted = false
        """
    )
    agent_exists = await session.execute(agent_check_query, {"agent_id": agent_id, "org_id": actor.organization_id})
    if not agent_exists.fetchone():
        # Same error regardless of "missing" vs "wrong org" — avoids leaking
        # the existence of agents in other organizations.
        raise NoResultFound(f"Agent with ID {agent_id} not found")
@staticmethod
def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
"""
@@ -1133,8 +1158,8 @@ class AgentManager:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def get_agent_by_id_async(
self,
agent_id: str,
@@ -1853,28 +1878,8 @@ class AgentManager:
# update agent in-context message IDs
await self.append_to_in_context_messages_async(messages=[message], agent_id=agent_id, actor=actor)
@trace_method
@enforce_types
def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
"""
Lists all sources attached to an agent.
Args:
agent_id: ID of the agent to list sources for
actor: User performing the action
Returns:
List[str]: List of source IDs attached to the agent
"""
with db_registry.session() as session:
# Verify agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Use the lazy-loaded relationship to get sources
return [source.to_pydantic() for source in agent.sources]
@trace_method
@enforce_types
async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
"""
Lists all sources attached to an agent.
@@ -1885,41 +1890,31 @@ class AgentManager:
Returns:
List[str]: List of source IDs attached to the agent
Raises:
NoResultFound: If agent doesn't exist or user doesn't have access
"""
async with db_registry.async_session() as session:
# Verify agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Validate agent exists and user has access
await self._validate_agent_exists_async(session, agent_id, actor)
# Use the lazy-loaded relationship to get sources
return [source.to_pydantic() for source in agent.sources]
# Use raw SQL to efficiently fetch sources - much faster than lazy loading
# Fast query without relationship loading
query = (
select(SourceModel)
.join(SourcesAgents, SourceModel.id == SourcesAgents.source_id)
.where(
SourcesAgents.agent_id == agent_id,
SourceModel.organization_id == actor.organization_id,
SourceModel.is_deleted == False,
)
.order_by(SourceModel.created_at.desc(), SourceModel.id)
)
@trace_method
@enforce_types
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a source from an agent.
result = await session.execute(query)
sources = result.scalars().all()
Args:
agent_id: ID of the agent to detach the source from
source_id: ID of the source to detach
actor: User performing the action
"""
with db_registry.session() as session:
# Verify agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Remove the source from the relationship
remaining_sources = [s for s in agent.sources if s.id != source_id]
if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
# Update the sources relationship
agent.sources = remaining_sources
# Commit the changes
agent.update(session, actor=actor)
return agent.to_pydantic()
return [source.to_pydantic() for source in sources]
@trace_method
@enforce_types
@@ -1931,22 +1926,29 @@ class AgentManager:
agent_id: ID of the agent to detach the source from
source_id: ID of the source to detach
actor: User performing the action
Raises:
NoResultFound: If agent doesn't exist or user doesn't have access
"""
async with db_registry.async_session() as session:
# Verify agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Validate agent exists and user has access
await self._validate_agent_exists_async(session, agent_id, actor)
# Remove the source from the relationship
remaining_sources = [s for s in agent.sources if s.id != source_id]
# Check if the source is actually attached to this agent using junction table
attachment_check_query = select(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
attachment_result = await session.execute(attachment_check_query)
attachment = attachment_result.scalar_one_or_none()
if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship
if not attachment:
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
else:
# Delete the association directly from the junction table
delete_query = delete(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
await session.execute(delete_query)
await session.commit()
# Update the sources relationship
agent.sources = remaining_sources
# Commit the changes
await agent.update_async(session, actor=actor)
# Get agent without loading relationships for return value
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
return await agent.to_pydantic_async()
# ======================================================================================================================

View File

@@ -25,7 +25,6 @@ class OpenAIEmbedder(BaseEmbedder):
else EmbeddingConfig.default_config(model_name="letta")
)
self.embedding_config = embedding_config or self.default_embedding_config
self.max_concurrent_requests = 20
# TODO: Unify to global OpenAI client
self.client: OpenAIClient = cast(
@@ -48,9 +47,55 @@ class OpenAIEmbedder(BaseEmbedder):
"embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
},
)
embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
return [(idx, e) for idx, e in zip(batch_indices, embeddings)]
try:
embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
return [(idx, e) for idx, e in zip(batch_indices, embeddings)]
except Exception as e:
# if it's a token limit error and we can split, do it
if self._is_token_limit_error(e) and len(batch) > 1:
logger.warning(f"Token limit exceeded for batch of size {len(batch)}, splitting in half and retrying")
log_event(
"embedder.batch_split_retry",
{
"original_batch_size": len(batch),
"error": str(e),
"split_size": len(batch) // 2,
},
)
# split batch in half
mid = len(batch) // 2
batch1 = batch[:mid]
batch1_indices = batch_indices[:mid]
batch2 = batch[mid:]
batch2_indices = batch_indices[mid:]
# retry with smaller batches
result1 = await self._embed_batch(batch1, batch1_indices)
result2 = await self._embed_batch(batch2, batch2_indices)
return result1 + result2
else:
# re-raise for other errors or if batch size is already 1
raise
def _is_token_limit_error(self, error: Exception) -> bool:
"""Check if the error is due to token limit exceeded"""
# convert to string and check for token limit patterns
error_str = str(error).lower()
# TODO: This is quite brittle, works for now
# check for the specific patterns we see in token limit errors
is_token_limit = (
"max_tokens_per_request" in error_str
or ("requested" in error_str and "tokens" in error_str and "max" in error_str and "per request" in error_str)
or "token limit" in error_str
or ("bad request to openai" in error_str and "tokens" in error_str and "max" in error_str)
)
return is_token_limit
@trace_method
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
@@ -100,7 +145,7 @@ class OpenAIEmbedder(BaseEmbedder):
log_event(
"embedder.concurrent_processing_started",
{"concurrent_tasks": len(tasks), "max_concurrent_requests": self.max_concurrent_requests},
{"concurrent_tasks": len(tasks)},
)
results = await asyncio.gather(*tasks)
log_event("embedder.concurrent_processing_completed", {"batches_processed": len(results)})

View File

@@ -29,6 +29,7 @@ class FileAgentManager:
agent_id: str,
file_id: str,
file_name: str,
source_id: str,
actor: PydanticUser,
is_open: bool = True,
visible_content: Optional[str] = None,
@@ -47,7 +48,12 @@ class FileAgentManager:
if is_open:
# Use the efficient LRU + open method
closed_files, was_already_open = await self.enforce_max_open_files_and_open(
agent_id=agent_id, file_id=file_id, file_name=file_name, actor=actor, visible_content=visible_content or ""
agent_id=agent_id,
file_id=file_id,
file_name=file_name,
source_id=source_id,
actor=actor,
visible_content=visible_content or "",
)
# Get the updated file agent to return
@@ -85,6 +91,7 @@ class FileAgentManager:
agent_id=agent_id,
file_id=file_id,
file_name=file_name,
source_id=source_id,
organization_id=actor.organization_id,
is_open=is_open,
visible_content=visible_content,
@@ -327,7 +334,7 @@ class FileAgentManager:
@enforce_types
@trace_method
async def enforce_max_open_files_and_open(
self, *, agent_id: str, file_id: str, file_name: str, actor: PydanticUser, visible_content: str
self, *, agent_id: str, file_id: str, file_name: str, source_id: str, actor: PydanticUser, visible_content: str
) -> tuple[List[str], bool]:
"""
Efficiently handle LRU eviction and file opening in a single transaction.
@@ -336,6 +343,7 @@ class FileAgentManager:
agent_id: ID of the agent
file_id: ID of the file to open
file_name: Name of the file to open
source_id: ID of the source (denormalized from files.source_id)
actor: User performing the action
visible_content: Content to set for the opened file
@@ -418,6 +426,7 @@ class FileAgentManager:
agent_id=agent_id,
file_id=file_id,
file_name=file_name,
source_id=source_id,
organization_id=actor.organization_id,
is_open=True,
visible_content=visible_content,
@@ -516,6 +525,7 @@ class FileAgentManager:
agent_id=agent_id,
file_id=meta.id,
file_name=meta.file_name,
source_id=meta.source_id,
organization_id=actor.organization_id,
is_open=is_now_open,
visible_content=vc,

View File

@@ -1,8 +1,12 @@
import asyncio
from typing import List, Optional
from sqlalchemy import select
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.source import Source as SourceModel
from letta.orm.sources_agents import SourcesAgents
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.source import Source as PydanticSource
@@ -104,9 +108,21 @@ class SourceManager:
# Verify source exists and user has permission to access it
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
# The agents relationship is already loaded due to lazy="selectin" in the Source model
# and will be properly filtered by organization_id due to the OrganizationMixin
agents_orm = source.agents
# Use junction table query instead of relationship to avoid performance issues
query = (
select(AgentModel)
.join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
.where(
SourcesAgents.source_id == source_id,
AgentModel.organization_id == actor.organization_id if actor else True,
AgentModel.is_deleted == False,
)
.order_by(AgentModel.created_at.desc(), AgentModel.id)
)
result = await session.execute(query)
agents_orm = result.scalars().all()
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons

View File

@@ -180,7 +180,12 @@ class LettaFileToolExecutor(ToolExecutor):
# Handle LRU eviction and file opening
closed_files, was_already_open = await self.files_agents_manager.enforce_max_open_files_and_open(
agent_id=agent_state.id, file_id=file_id, file_name=file_name, actor=self.actor, visible_content=visible_content
agent_id=agent_state.id,
file_id=file_id,
file_name=file_name,
source_id=file.source_id,
actor=self.actor,
visible_content=visible_content,
)
opened_files.append(file_name)

View File

@@ -0,0 +1,219 @@
from unittest.mock import AsyncMock, Mock, patch
import openai
import pytest
from letta.errors import ErrorCode, LLMBadRequestError
from letta.schemas.embedding_config import EmbeddingConfig
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
class TestOpenAIEmbedder:
    """Test suite for OpenAI embedder functionality.

    The OpenAI client is replaced with a Mock in the ``embedder`` fixture, so
    no network calls occur; each test drives ``request_embeddings`` with
    canned results or injected errors.
    """

    @pytest.fixture
    def mock_user(self):
        """Create a mock user for testing"""
        user = Mock()
        user.organization_id = "test_org_id"
        return user

    @pytest.fixture
    def embedding_config(self):
        """Create a test embedding config"""
        return EmbeddingConfig(
            embedding_model="text-embedding-3-small",
            embedding_endpoint_type="openai",
            embedding_endpoint="https://api.openai.com/v1",
            embedding_dim=3,  # small dimension for testing
            embedding_chunk_size=300,
            batch_size=2,  # small batch size for testing
        )

    @pytest.fixture
    def embedder(self, embedding_config):
        """Create OpenAI embedder with test config"""
        # patch LLMClient.create so constructing the embedder never builds a real client
        with patch("letta.services.file_processor.embedder.openai_embedder.LLMClient.create") as mock_create:
            mock_client = Mock()
            mock_client.handle_llm_error = Mock()
            mock_create.return_value = mock_client
            embedder = OpenAIEmbedder(embedding_config)
            # replace again post-construction in case __init__ stored its own client
            embedder.client = mock_client
            return embedder

    @pytest.mark.asyncio
    async def test_successful_embedding_generation(self, embedder, mock_user):
        """Test successful embedding generation for normal cases"""
        # mock successful embedding response
        mock_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        embedder.client.request_embeddings = AsyncMock(return_value=mock_embeddings)
        chunks = ["chunk 1", "chunk 2"]
        file_id = "test_file"
        source_id = "test_source"
        passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
        assert len(passages) == 2
        assert passages[0].text == "chunk 1"
        assert passages[1].text == "chunk 2"
        # embeddings are padded to MAX_EMBEDDING_DIM, so check first 3 values
        assert passages[0].embedding[:3] == [0.1, 0.2, 0.3]
        assert passages[1].embedding[:3] == [0.4, 0.5, 0.6]
        assert passages[0].file_id == file_id
        assert passages[0].source_id == source_id

    @pytest.mark.asyncio
    async def test_token_limit_retry_splits_batch(self, embedder, mock_user):
        """Test that token limit errors trigger batch splitting and retry"""
        # create a mock token limit error
        mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
        token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
        # first call fails with token limit, subsequent calls succeed
        call_count = 0

        async def mock_request_embeddings(inputs, embedding_config):
            # stateful stub: fail once on the full batch, then succeed on halves
            nonlocal call_count
            call_count += 1
            if call_count == 1 and len(inputs) == 4:  # first call with full batch
                raise token_limit_error
            elif len(inputs) == 2:  # split batches succeed
                return [[0.1, 0.2], [0.3, 0.4]] if call_count == 2 else [[0.5, 0.6], [0.7, 0.8]]
            else:
                return [[0.1, 0.2]] * len(inputs)

        embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings)
        chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"]
        file_id = "test_file"
        source_id = "test_source"
        passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
        # should still get all 4 passages despite the retry
        assert len(passages) == 4
        assert all(len(p.embedding) == 4096 for p in passages)  # padded to MAX_EMBEDDING_DIM
        # verify multiple calls were made (original + retries)
        assert call_count >= 2

    @pytest.mark.asyncio
    async def test_token_limit_error_detection(self, embedder):
        """Test various token limit error detection patterns"""
        # test openai BadRequestError with proper structure
        mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
        openai_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
        assert embedder._is_token_limit_error(openai_error) is True
        # test error with message but no code
        mock_error_body_no_code = {"error": {"message": "max_tokens_per_request exceeded"}}
        openai_error_no_code = openai.BadRequestError(
            message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body_no_code
        )
        assert embedder._is_token_limit_error(openai_error_no_code) is True
        # test fallback string detection
        generic_error = Exception("Requested 100000 tokens, max 50000 tokens per request")
        assert embedder._is_token_limit_error(generic_error) is True
        # test non-token errors
        other_error = Exception("Some other error")
        assert embedder._is_token_limit_error(other_error) is False
        auth_error = openai.AuthenticationError(
            message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}}
        )
        assert embedder._is_token_limit_error(auth_error) is False

    @pytest.mark.asyncio
    async def test_non_token_error_handling(self, embedder, mock_user):
        """Test that non-token errors are properly handled and re-raised"""
        # create a non-token error
        auth_error = openai.AuthenticationError(
            message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}}
        )
        # mock handle_llm_error to return a standardized error
        handled_error = LLMBadRequestError(message="Handled error", code=ErrorCode.UNAUTHENTICATED)
        embedder.client.handle_llm_error.return_value = handled_error
        embedder.client.request_embeddings = AsyncMock(side_effect=auth_error)
        chunks = ["chunk 1"]
        file_id = "test_file"
        source_id = "test_source"
        with pytest.raises(LLMBadRequestError) as exc_info:
            await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
        # the raw error must be routed through the client's error translator
        assert exc_info.value == handled_error
        embedder.client.handle_llm_error.assert_called_once_with(auth_error)

    @pytest.mark.asyncio
    async def test_single_item_batch_no_retry(self, embedder, mock_user):
        """Test that single-item batches don't retry on token limit errors"""
        # create a token limit error
        mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
        token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
        handled_error = LLMBadRequestError(message="Handled token limit error", code=ErrorCode.INVALID_ARGUMENT)
        embedder.client.handle_llm_error.return_value = handled_error
        embedder.client.request_embeddings = AsyncMock(side_effect=token_limit_error)
        # a batch of one cannot be split further, so the error must propagate
        chunks = ["very long chunk that exceeds token limit"]
        file_id = "test_file"
        source_id = "test_source"
        with pytest.raises(LLMBadRequestError) as exc_info:
            await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
        assert exc_info.value == handled_error
        embedder.client.handle_llm_error.assert_called_once_with(token_limit_error)

    @pytest.mark.asyncio
    async def test_empty_chunks_handling(self, embedder, mock_user):
        """Test handling of empty chunks list"""
        chunks = []
        file_id = "test_file"
        source_id = "test_source"
        passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
        assert passages == []
        # should not call request_embeddings for empty input
        embedder.client.request_embeddings.assert_not_called()

    @pytest.mark.asyncio
    async def test_embedding_order_preservation(self, embedder, mock_user):
        """Test that embedding order is preserved even with retries"""
        # set up embedder to split batches (batch_size=2)
        embedder.embedding_config.batch_size = 2

        # mock responses for each batch
        async def mock_request_embeddings(inputs, embedding_config):
            # return embeddings that correspond to input order
            if inputs == ["chunk 1", "chunk 2"]:
                return [[0.1, 0.1], [0.2, 0.2]]
            elif inputs == ["chunk 3", "chunk 4"]:
                return [[0.3, 0.3], [0.4, 0.4]]
            else:
                return [[0.1, 0.1]] * len(inputs)

        embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings)
        chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"]
        file_id = "test_file"
        source_id = "test_source"
        passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
        # verify order is preserved
        assert len(passages) == 4
        assert passages[0].text == "chunk 1"
        assert passages[0].embedding[:2] == [0.1, 0.1]  # check first 2 values before padding
        assert passages[1].text == "chunk 2"
        assert passages[1].embedding[:2] == [0.2, 0.2]
        assert passages[2].text == "chunk 3"
        assert passages[2].embedding[:2] == [0.3, 0.3]
        assert passages[3].text == "chunk 4"
        assert passages[3].embedding[:2] == [0.4, 0.4]

View File

@@ -673,6 +673,7 @@ async def file_attachment(server, default_user, sarah_agent, default_file):
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
visible_content="initial",
)
@@ -903,6 +904,7 @@ async def test_get_context_window_basic(
agent_id=created_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
visible_content="hello",
)
@@ -7221,6 +7223,7 @@ async def test_attach_creates_association(server, default_user, sarah_agent, def
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
visible_content="hello",
)
@@ -7243,6 +7246,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
visible_content="first",
)
@@ -7252,6 +7256,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
is_open=False,
visible_content="second",
@@ -7326,15 +7331,28 @@ async def test_list_files_and_agents(
):
# default_file ↔ charles (open)
await server.file_agent_manager.attach_file(
agent_id=charles_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user
agent_id=charles_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
)
# default_file ↔ sarah (open)
await server.file_agent_manager.attach_file(
agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
)
# another_file ↔ sarah (closed)
await server.file_agent_manager.attach_file(
agent_id=sarah_agent.id, file_id=another_file.id, file_name=another_file.file_name, actor=default_user, is_open=False
agent_id=sarah_agent.id,
file_id=another_file.id,
file_name=another_file.file_name,
source_id=another_file.source_id,
actor=default_user,
is_open=False,
)
files_for_sarah = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user)
@@ -7384,6 +7402,7 @@ async def test_org_scoping(
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
)
@@ -7420,6 +7439,7 @@ async def test_mark_access_bulk(server, default_user, sarah_agent, default_sourc
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content=f"content for {file.file_name}",
)
@@ -7478,6 +7498,7 @@ async def test_lru_eviction_on_attach(server, default_user, sarah_agent, default
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content=f"content for {file.file_name}",
)
@@ -7530,6 +7551,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
agent_id=sarah_agent.id,
file_id=files[i].id,
file_name=files[i].file_name,
source_id=files[i].source_id,
actor=default_user,
visible_content=f"content for {files[i].file_name}",
)
@@ -7539,6 +7561,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
agent_id=sarah_agent.id,
file_id=files[-1].id,
file_name=files[-1].file_name,
source_id=files[-1].source_id,
actor=default_user,
is_open=False,
visible_content=f"content for {files[-1].file_name}",
@@ -7555,7 +7578,12 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
# Now "open" the last file using the efficient method
closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open(
agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content"
agent_id=sarah_agent.id,
file_id=files[-1].id,
file_name=files[-1].file_name,
source_id=files[-1].source_id,
actor=default_user,
visible_content="updated content",
)
# Should have closed 1 file (the oldest one)
@@ -7603,6 +7631,7 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content=f"content for {file.file_name}",
)
@@ -7617,7 +7646,12 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa
# "Reopen" the last file (which is already open)
closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open(
agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content"
agent_id=sarah_agent.id,
file_id=files[-1].id,
file_name=files[-1].file_name,
source_id=files[-1].source_id,
actor=default_user,
visible_content="updated content",
)
# Should not have closed any files since we're within the limit
@@ -7645,7 +7679,12 @@ async def test_last_accessed_at_updates_correctly(server, default_user, sarah_ag
file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user, text="test content")
file_agent, closed_files = await server.file_agent_manager.attach_file(
agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, actor=default_user, visible_content="initial content"
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content="initial content",
)
initial_time = file_agent.last_accessed_at
@@ -7777,6 +7816,7 @@ async def test_attach_files_bulk_lru_eviction(server, default_user, sarah_agent,
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content=f"existing content {i}",
)
@@ -7842,6 +7882,7 @@ async def test_attach_files_bulk_mixed_existing_new(server, default_user, sarah_
agent_id=sarah_agent.id,
file_id=existing_file.id,
file_name=existing_file.file_name,
source_id=existing_file.source_id,
actor=default_user,
visible_content="old content",
is_open=False, # Start as closed