From c94b227a3258903790cf13d85523d92a72297aad Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 10 Jul 2025 18:00:35 -0700 Subject: [PATCH] feat: Improve performance on files related data models (#3285) --- ...rite_source_id_directly_to_files_agents.py | 52 +++++ letta/orm/file.py | 19 +- letta/orm/files_agents.py | 19 +- letta/orm/organization.py | 4 - letta/orm/passage.py | 10 - letta/orm/source.py | 23 +- letta/schemas/file.py | 1 + letta/services/agent_manager.py | 124 +++++----- .../embedder/openai_embedder.py | 55 ++++- letta/services/files_agents_manager.py | 14 +- letta/services/source_manager.py | 22 +- .../tool_executor/files_tool_executor.py | 7 +- tests/test_file_processor.py | 219 ++++++++++++++++++ tests/test_managers.py | 53 ++++- 14 files changed, 483 insertions(+), 139 deletions(-) create mode 100644 alembic/versions/495f3f474131_write_source_id_directly_to_files_agents.py create mode 100644 tests/test_file_processor.py diff --git a/alembic/versions/495f3f474131_write_source_id_directly_to_files_agents.py b/alembic/versions/495f3f474131_write_source_id_directly_to_files_agents.py new file mode 100644 index 00000000..9319e99c --- /dev/null +++ b/alembic/versions/495f3f474131_write_source_id_directly_to_files_agents.py @@ -0,0 +1,52 @@ +"""Write source_id directly to files agents + +Revision ID: 495f3f474131 +Revises: 47d2277e530d +Create Date: 2025-07-10 17:14:45.154738 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "495f3f474131" +down_revision: Union[str, None] = "47d2277e530d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Step 1: Add the column as nullable first + op.add_column("files_agents", sa.Column("source_id", sa.String(), nullable=True)) + + # Step 2: Backfill source_id from files table + connection = op.get_bind() + connection.execute( + sa.text( + """ + UPDATE files_agents + SET source_id = files.source_id + FROM files + WHERE files_agents.file_id = files.id + """ + ) + ) + + # Step 3: Make the column NOT NULL now that it's populated + op.alter_column("files_agents", "source_id", nullable=False) + + # Step 4: Add the foreign key constraint + op.create_foreign_key(None, "files_agents", "sources", ["source_id"], ["id"], ondelete="CASCADE") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "files_agents", type_="foreignkey") + op.drop_column("files_agents", "source_id") + # ### end Alembic commands ### diff --git a/letta/orm/file.py b/letta/orm/file.py index 885731e5..8cae2448 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -1,5 +1,5 @@ import uuid -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from sqlalchemy import ForeignKey, Index, Integer, String, Text, UniqueConstraint, desc from sqlalchemy.ext.asyncio import AsyncAttrs @@ -11,10 +11,7 @@ from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata as PydanticFileMetadata if TYPE_CHECKING: - from letta.orm.files_agents import FileAgent - from letta.orm.organization import Organization - from letta.orm.passage import SourcePassage - from letta.orm.source import Source + pass # TODO: Note that this is NOT organization scoped, this is potentially dangerous if we misuse this @@ -64,18 +61,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): chunks_embedded: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Number of chunks that have been embedded.") # relationships - organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin") - source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin") - source_passages: Mapped[List["SourcePassage"]] = relationship( - "SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan" - ) - file_agents: Mapped[List["FileAgent"]] = relationship( - "FileAgent", - back_populates="file", - lazy="selectin", - cascade="all, delete-orphan", - passive_deletes=True, # ← add this - ) content: Mapped[Optional["FileContent"]] = relationship( "FileContent", uselist=False, diff --git a/letta/orm/files_agents.py b/letta/orm/files_agents.py index f7398a91..d8fd5c2f 100644 --- a/letta/orm/files_agents.py +++ b/letta/orm/files_agents.py @@ -12,7 +12,7 @@ from letta.schemas.block import Block as PydanticBlock from letta.schemas.file import FileAgent as PydanticFileAgent if TYPE_CHECKING: - from letta.orm.file import FileMetadata + pass class FileAgent(SqlalchemyBase, OrganizationMixin): @@ -55,6 +55,12 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): nullable=False, doc="ID of the agent", ) + source_id: Mapped[str] = mapped_column( + String, + ForeignKey("sources.id", ondelete="CASCADE"), + nullable=False, + doc="ID of the source (denormalized from files.source_id)", + ) file_name: Mapped[str] = mapped_column( String, @@ -78,13 +84,6 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): back_populates="file_agents", lazy="selectin", ) - file: Mapped["FileMetadata"] = relationship( - "FileMetadata", - foreign_keys=[file_id], - lazy="selectin", - back_populates="file_agents", - passive_deletes=True, # ← add this - ) # TODO: This is temporary as we figure out if we want FileBlock as a first class citizen def to_pydantic_block(self) -> PydanticBlock: @@ -99,8 +98,8 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): return PydanticBlock( organization_id=self.organization_id, value=visible_content, - label=self.file.file_name, + label=self.file_name, # use denormalized file_name instead of self.file.file_name read_only=True, - metadata={"source_id": self.file.source_id}, + metadata={"source_id": self.source_id}, # use denormalized source_id limit=CORE_MEMORY_SOURCE_CHAR_LIMIT, ) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index e1937633..f5f65cb9 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: from letta.orm.agent import Agent from letta.orm.agent_passage import AgentPassage from letta.orm.block import Block - from letta.orm.file import FileMetadata from letta.orm.group import Group from letta.orm.identity import Identity from letta.orm.llm_batch_item import LLMBatchItem @@ -18,7 +17,6 @@ if TYPE_CHECKING: from letta.orm.provider import Provider from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig from letta.orm.sandbox_environment_variable import SandboxEnvironmentVariable - from letta.orm.source import Source from letta.orm.source_passage import SourcePassage from letta.orm.tool import Tool from letta.orm.user import User @@ -38,8 +36,6 @@ class Organization(SqlalchemyBase): tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") # mcp_servers: Mapped[List["MCPServer"]] = relationship("MCPServer", back_populates="organization", cascade="all, delete-orphan") blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan") - sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") - files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan") sandbox_configs: Mapped[List["SandboxConfig"]] = relationship( "SandboxConfig", back_populates="organization", cascade="all, delete-orphan" ) diff --git a/letta/orm/passage.py b/letta/orm/passage.py index 82451027..868f8a67 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -49,11 +49,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin): file_name: Mapped[str] = mapped_column(doc="The name of the file that this passage was derived from") - @declared_attr - def file(cls) -> Mapped["FileMetadata"]: - """Relationship to file""" - return relationship("FileMetadata", back_populates="source_passages", lazy="selectin") - @declared_attr def organization(cls) -> Mapped["Organization"]: return relationship("Organization", back_populates="source_passages", lazy="selectin") @@ -74,11 +69,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin): {"extend_existing": True}, ) - @declared_attr - def source(cls) -> Mapped["Source"]: - """Relationship to source""" - return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True) - class AgentPassage(BasePassage, AgentMixin): """Passages created by agents as archival memories""" diff --git a/letta/orm/source.py b/letta/orm/source.py index c4a0f2d9..f23c61e5 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -1,9 +1,8 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from sqlalchemy import JSON, Index, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column -from letta.orm import FileMetadata from letta.orm.custom_columns import EmbeddingConfigColumn from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase @@ -11,10 +10,7 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.source import Source as PydanticSource if TYPE_CHECKING: - from letta.orm.agent import Agent - from letta.orm.file import FileMetadata - from letta.orm.organization import Organization - from letta.orm.passage import SourcePassage + pass class Source(SqlalchemyBase, OrganizationMixin): @@ -34,16 +30,3 @@ class Source(SqlalchemyBase, OrganizationMixin): instructions: Mapped[str] = mapped_column(nullable=True, doc="instructions for how to use the source") embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.") metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.") - - # relationships - organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") - files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan") - passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan") - agents: Mapped[List["Agent"]] = relationship( - "Agent", - secondary="sources_agents", - back_populates="sources", - lazy="selectin", - cascade="save-update", # Only propagate save and update operations - passive_deletes=True, # Let the database handle deletions - ) diff --git a/letta/schemas/file.py b/letta/schemas/file.py index 14e2a122..90132c50 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -85,6 +85,7 @@ class FileAgent(FileAgentBase): ) agent_id: str = Field(..., description="Unique identifier of the agent.") file_id: str = Field(..., description="Unique identifier of the file.") + source_id: str = Field(..., description="Unique identifier of the source (denormalized from files.source_id).") file_name: str = Field(..., description="Name of the file.") is_open: bool = Field(True, description="True if the agent currently has the file open.") visible_content: Optional[str] = Field( diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 44a8eb6f..953be77b 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -107,6 +107,31 @@ class AgentManager: self.identity_manager = IdentityManager() self.file_agent_manager = FileAgentManager() + async def _validate_agent_exists_async(self, session, agent_id: str, actor: PydanticUser) -> None: + """ + Validate that an agent exists and user has access to it using raw SQL for efficiency. + + Args: + session: Database session + agent_id: ID of the agent to validate + actor: User performing the action + + Raises: + NoResultFound: If agent doesn't exist or user doesn't have access + """ + agent_check_query = sa.text( + """ + SELECT 1 FROM agents + WHERE id = :agent_id + AND organization_id = :org_id + AND is_deleted = false + """ + ) + agent_exists = await session.execute(agent_check_query, {"agent_id": agent_id, "org_id": actor.organization_id}) + + if not agent_exists.fetchone(): + raise NoResultFound(f"Agent with ID {agent_id} not found") + @staticmethod def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]: """ @@ -1133,8 +1158,8 @@ class AgentManager: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) return agent.to_pydantic() - @trace_method @enforce_types + @trace_method async def get_agent_by_id_async( self, agent_id: str, @@ -1853,28 +1878,8 @@ class AgentManager: # update agent in-context message IDs await self.append_to_in_context_messages_async(messages=[message], agent_id=agent_id, actor=actor) - @trace_method @enforce_types - def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: - """ - Lists all sources attached to an agent. - - Args: - agent_id: ID of the agent to list sources for - actor: User performing the action - - Returns: - List[str]: List of source IDs attached to the agent - """ - with db_registry.session() as session: - # Verify agent exists and user has permission to access it - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # Use the lazy-loaded relationship to get sources - return [source.to_pydantic() for source in agent.sources] - @trace_method - @enforce_types async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: """ Lists all sources attached to an agent. @@ -1885,41 +1890,31 @@ class AgentManager: Returns: List[str]: List of source IDs attached to the agent + + Raises: + NoResultFound: If agent doesn't exist or user doesn't have access """ async with db_registry.async_session() as session: - # Verify agent exists and user has permission to access it - agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + # Validate agent exists and user has access + await self._validate_agent_exists_async(session, agent_id, actor) - # Use the lazy-loaded relationship to get sources - return [source.to_pydantic() for source in agent.sources] + # Use raw SQL to efficiently fetch sources - much faster than lazy loading + # Fast query without relationship loading + query = ( + select(SourceModel) + .join(SourcesAgents, SourceModel.id == SourcesAgents.source_id) + .where( + SourcesAgents.agent_id == agent_id, + SourceModel.organization_id == actor.organization_id, + SourceModel.is_deleted == False, + ) + .order_by(SourceModel.created_at.desc(), SourceModel.id) + ) - @trace_method - @enforce_types - def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: - """ - Detaches a source from an agent. + result = await session.execute(query) + sources = result.scalars().all() - Args: - agent_id: ID of the agent to detach the source from - source_id: ID of the source to detach - actor: User performing the action - """ - with db_registry.session() as session: - # Verify agent exists and user has permission to access it - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # Remove the source from the relationship - remaining_sources = [s for s in agent.sources if s.id != source_id] - - if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship - logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}") - - # Update the sources relationship - agent.sources = remaining_sources - - # Commit the changes - agent.update(session, actor=actor) - return agent.to_pydantic() + return [source.to_pydantic() for source in sources] @trace_method @enforce_types @@ -1931,22 +1926,29 @@ class AgentManager: agent_id: ID of the agent to detach the source from source_id: ID of the source to detach actor: User performing the action + + Raises: + NoResultFound: If agent doesn't exist or user doesn't have access """ async with db_registry.async_session() as session: - # Verify agent exists and user has permission to access it - agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + # Validate agent exists and user has access + await self._validate_agent_exists_async(session, agent_id, actor) - # Remove the source from the relationship - remaining_sources = [s for s in agent.sources if s.id != source_id] + # Check if the source is actually attached to this agent using junction table + attachment_check_query = select(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id) + attachment_result = await session.execute(attachment_check_query) + attachment = attachment_result.scalar_one_or_none() - if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship + if not attachment: logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}") + else: + # Delete the association directly from the junction table + delete_query = delete(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id) + await session.execute(delete_query) + await session.commit() - # Update the sources relationship - agent.sources = remaining_sources - - # Commit the changes - await agent.update_async(session, actor=actor) + # Get agent without loading relationships for return value + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) return await agent.to_pydantic_async() # ====================================================================================================================== diff --git a/letta/services/file_processor/embedder/openai_embedder.py b/letta/services/file_processor/embedder/openai_embedder.py index 5a888549..b55ba936 100644 --- a/letta/services/file_processor/embedder/openai_embedder.py +++ b/letta/services/file_processor/embedder/openai_embedder.py @@ -25,7 +25,6 @@ class OpenAIEmbedder(BaseEmbedder): else EmbeddingConfig.default_config(model_name="letta") ) self.embedding_config = embedding_config or self.default_embedding_config - self.max_concurrent_requests = 20 # TODO: Unify to global OpenAI client self.client: OpenAIClient = cast( @@ -48,9 +47,55 @@ class OpenAIEmbedder(BaseEmbedder): "embedding_endpoint_type": self.embedding_config.embedding_endpoint_type, }, ) - embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config) - log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)}) - return [(idx, e) for idx, e in zip(batch_indices, embeddings)] + + try: + embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config) + log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)}) + return [(idx, e) for idx, e in zip(batch_indices, embeddings)] + except Exception as e: + # if it's a token limit error and we can split, do it + if self._is_token_limit_error(e) and len(batch) > 1: + logger.warning(f"Token limit exceeded for batch of size {len(batch)}, splitting in half and retrying") + log_event( + "embedder.batch_split_retry", + { + "original_batch_size": len(batch), + "error": str(e), + "split_size": len(batch) // 2, + }, + ) + + # split batch in half + mid = len(batch) // 2 + batch1 = batch[:mid] + batch1_indices = batch_indices[:mid] + batch2 = batch[mid:] + batch2_indices = batch_indices[mid:] + + # retry with smaller batches + result1 = await self._embed_batch(batch1, batch1_indices) + result2 = await self._embed_batch(batch2, batch2_indices) + + return result1 + result2 + else: + # re-raise for other errors or if batch size is already 1 + raise + + def _is_token_limit_error(self, error: Exception) -> bool: + """Check if the error is due to token limit exceeded""" + # convert to string and check for token limit patterns + error_str = str(error).lower() + + # TODO: This is quite brittle, works for now + # check for the specific patterns we see in token limit errors + is_token_limit = ( + "max_tokens_per_request" in error_str + or ("requested" in error_str and "tokens" in error_str and "max" in error_str and "per request" in error_str) + or "token limit" in error_str + or ("bad request to openai" in error_str and "tokens" in error_str and "max" in error_str) + ) + + return is_token_limit @trace_method async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]: @@ -100,7 +145,7 @@ class OpenAIEmbedder(BaseEmbedder): log_event( "embedder.concurrent_processing_started", - {"concurrent_tasks": len(tasks), "max_concurrent_requests": self.max_concurrent_requests}, + {"concurrent_tasks": len(tasks)}, ) results = await asyncio.gather(*tasks) log_event("embedder.concurrent_processing_completed", {"batches_processed": len(results)}) diff --git a/letta/services/files_agents_manager.py b/letta/services/files_agents_manager.py index a4abab31..0264f8dc 100644 --- a/letta/services/files_agents_manager.py +++ b/letta/services/files_agents_manager.py @@ -29,6 +29,7 @@ class FileAgentManager: agent_id: str, file_id: str, file_name: str, + source_id: str, actor: PydanticUser, is_open: bool = True, visible_content: Optional[str] = None, @@ -47,7 +48,12 @@ class FileAgentManager: if is_open: # Use the efficient LRU + open method closed_files, was_already_open = await self.enforce_max_open_files_and_open( - agent_id=agent_id, file_id=file_id, file_name=file_name, actor=actor, visible_content=visible_content or "" + agent_id=agent_id, + file_id=file_id, + file_name=file_name, + source_id=source_id, + actor=actor, + visible_content=visible_content or "", ) # Get the updated file agent to return @@ -85,6 +91,7 @@ class FileAgentManager: agent_id=agent_id, file_id=file_id, file_name=file_name, + source_id=source_id, organization_id=actor.organization_id, is_open=is_open, visible_content=visible_content, @@ -327,7 +334,7 @@ class FileAgentManager: @enforce_types @trace_method async def enforce_max_open_files_and_open( - self, *, agent_id: str, file_id: str, file_name: str, actor: PydanticUser, visible_content: str + self, *, agent_id: str, file_id: str, file_name: str, source_id: str, actor: PydanticUser, visible_content: str ) -> tuple[List[str], bool]: """ Efficiently handle LRU eviction and file opening in a single transaction. @@ -336,6 +343,7 @@ class FileAgentManager: agent_id: ID of the agent file_id: ID of the file to open file_name: Name of the file to open + source_id: ID of the source (denormalized from files.source_id) actor: User performing the action visible_content: Content to set for the opened file @@ -418,6 +426,7 @@ class FileAgentManager: agent_id=agent_id, file_id=file_id, file_name=file_name, + source_id=source_id, organization_id=actor.organization_id, is_open=True, visible_content=visible_content, @@ -516,6 +525,7 @@ class FileAgentManager: agent_id=agent_id, file_id=meta.id, file_name=meta.file_name, + source_id=meta.source_id, organization_id=actor.organization_id, is_open=is_now_open, visible_content=vc, diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 9e3ee4d2..b3cd2c04 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -1,8 +1,12 @@ import asyncio from typing import List, Optional +from sqlalchemy import select + +from letta.orm import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.source import Source as SourceModel +from letta.orm.sources_agents import SourcesAgents from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.source import Source as PydanticSource @@ -104,9 +108,21 @@ class SourceManager: # Verify source exists and user has permission to access it source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor) - # The agents relationship is already loaded due to lazy="selectin" in the Source model - # and will be properly filtered by organization_id due to the OrganizationMixin - agents_orm = source.agents + # Use junction table query instead of relationship to avoid performance issues + query = ( + select(AgentModel) + .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id) + .where( + SourcesAgents.source_id == source_id, + AgentModel.organization_id == actor.organization_id if actor else True, + AgentModel.is_deleted == False, + ) + .order_by(AgentModel.created_at.desc(), AgentModel.id) + ) + + result = await session.execute(query) + agents_orm = result.scalars().all() + return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm]) # TODO: We make actor optional for now, but should most likely be enforced due to security reasons diff --git a/letta/services/tool_executor/files_tool_executor.py b/letta/services/tool_executor/files_tool_executor.py index 4815243a..b56b2253 100644 --- a/letta/services/tool_executor/files_tool_executor.py +++ b/letta/services/tool_executor/files_tool_executor.py @@ -180,7 +180,12 @@ class LettaFileToolExecutor(ToolExecutor): # Handle LRU eviction and file opening closed_files, was_already_open = await self.files_agents_manager.enforce_max_open_files_and_open( - agent_id=agent_state.id, file_id=file_id, file_name=file_name, actor=self.actor, visible_content=visible_content + agent_id=agent_state.id, + file_id=file_id, + file_name=file_name, + source_id=file.source_id, + actor=self.actor, + visible_content=visible_content, ) opened_files.append(file_name) diff --git a/tests/test_file_processor.py b/tests/test_file_processor.py new file mode 100644 index 00000000..e2448e2e --- /dev/null +++ b/tests/test_file_processor.py @@ -0,0 +1,219 @@ +from unittest.mock import AsyncMock, Mock, patch + +import openai +import pytest + +from letta.errors import ErrorCode, LLMBadRequestError +from letta.schemas.embedding_config import EmbeddingConfig +from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder + + +class TestOpenAIEmbedder: + """Test suite for OpenAI embedder functionality""" + + @pytest.fixture + def mock_user(self): + """Create a mock user for testing""" + user = Mock() + user.organization_id = "test_org_id" + return user + + @pytest.fixture + def embedding_config(self): + """Create a test embedding config""" + return EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=3, # small dimension for testing + embedding_chunk_size=300, + batch_size=2, # small batch size for testing + ) + + @pytest.fixture + def embedder(self, embedding_config): + """Create OpenAI embedder with test config""" + with patch("letta.services.file_processor.embedder.openai_embedder.LLMClient.create") as mock_create: + mock_client = Mock() + mock_client.handle_llm_error = Mock() + mock_create.return_value = mock_client + + embedder = OpenAIEmbedder(embedding_config) + embedder.client = mock_client + return embedder + + @pytest.mark.asyncio + async def test_successful_embedding_generation(self, embedder, mock_user): + """Test successful embedding generation for normal cases""" + # mock successful embedding response + mock_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + embedder.client.request_embeddings = AsyncMock(return_value=mock_embeddings) + + chunks = ["chunk 1", "chunk 2"] + file_id = "test_file" + source_id = "test_source" + + passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + assert len(passages) == 2 + assert passages[0].text == "chunk 1" + assert passages[1].text == "chunk 2" + # embeddings are padded to MAX_EMBEDDING_DIM, so check first 3 values + assert passages[0].embedding[:3] == [0.1, 0.2, 0.3] + assert passages[1].embedding[:3] == [0.4, 0.5, 0.6] + assert passages[0].file_id == file_id + assert passages[0].source_id == source_id + + @pytest.mark.asyncio + async def test_token_limit_retry_splits_batch(self, embedder, mock_user): + """Test that token limit errors trigger batch splitting and retry""" + # create a mock token limit error + mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}} + token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body) + + # first call fails with token limit, subsequent calls succeed + call_count = 0 + + async def mock_request_embeddings(inputs, embedding_config): + nonlocal call_count + call_count += 1 + if call_count == 1 and len(inputs) == 4: # first call with full batch + raise token_limit_error + elif len(inputs) == 2: # split batches succeed + return [[0.1, 0.2], [0.3, 0.4]] if call_count == 2 else [[0.5, 0.6], [0.7, 0.8]] + else: + return [[0.1, 0.2]] * len(inputs) + + embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings) + + chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"] + file_id = "test_file" + source_id = "test_source" + + passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + # should still get all 4 passages despite the retry + assert len(passages) == 4 + assert all(len(p.embedding) == 4096 for p in passages) # padded to MAX_EMBEDDING_DIM + # verify multiple calls were made (original + retries) + assert call_count >= 2 + + @pytest.mark.asyncio + async def test_token_limit_error_detection(self, embedder): + """Test various token limit error detection patterns""" + # test openai BadRequestError with proper structure + mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}} + openai_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body) + assert embedder._is_token_limit_error(openai_error) is True + + # test error with message but no code + mock_error_body_no_code = {"error": {"message": "max_tokens_per_request exceeded"}} + openai_error_no_code = openai.BadRequestError( + message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body_no_code + ) + assert embedder._is_token_limit_error(openai_error_no_code) is True + + # test fallback string detection + generic_error = Exception("Requested 100000 tokens, max 50000 tokens per request") + assert embedder._is_token_limit_error(generic_error) is True + + # test non-token errors + other_error = Exception("Some other error") + assert embedder._is_token_limit_error(other_error) is False + + auth_error = openai.AuthenticationError( + message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}} + ) + assert embedder._is_token_limit_error(auth_error) is False + + @pytest.mark.asyncio + async def test_non_token_error_handling(self, embedder, mock_user): + """Test that non-token errors are properly handled and re-raised""" + # create a non-token error + auth_error = openai.AuthenticationError( + message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}} + ) + + # mock handle_llm_error to return a standardized error + handled_error = LLMBadRequestError(message="Handled error", code=ErrorCode.UNAUTHENTICATED) + embedder.client.handle_llm_error.return_value = handled_error + embedder.client.request_embeddings = AsyncMock(side_effect=auth_error) + + chunks = ["chunk 1"] + file_id = "test_file" + source_id = "test_source" + + with pytest.raises(LLMBadRequestError) as exc_info: + await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + assert exc_info.value == handled_error + embedder.client.handle_llm_error.assert_called_once_with(auth_error) + + @pytest.mark.asyncio + async def test_single_item_batch_no_retry(self, embedder, mock_user): + """Test that single-item batches don't retry on token limit errors""" + # create a token limit error + mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}} + token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body) + + handled_error = LLMBadRequestError(message="Handled token limit error", code=ErrorCode.INVALID_ARGUMENT) + embedder.client.handle_llm_error.return_value = handled_error + embedder.client.request_embeddings = AsyncMock(side_effect=token_limit_error) + + chunks = ["very long chunk that exceeds token limit"] + file_id = "test_file" + source_id = "test_source" + + with pytest.raises(LLMBadRequestError) as exc_info: + await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + assert exc_info.value == handled_error + embedder.client.handle_llm_error.assert_called_once_with(token_limit_error) + + @pytest.mark.asyncio + async def test_empty_chunks_handling(self, embedder, mock_user): + """Test handling of empty chunks list""" + chunks = [] + file_id = "test_file" + source_id = "test_source" + + passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + assert passages == [] + # should not call request_embeddings for empty input + embedder.client.request_embeddings.assert_not_called() + + @pytest.mark.asyncio + async def test_embedding_order_preservation(self, embedder, mock_user): + """Test that embedding order is preserved even with retries""" + # set up embedder to split batches (batch_size=2) + embedder.embedding_config.batch_size = 2 + + # mock responses for each batch + async def mock_request_embeddings(inputs, embedding_config): + # return embeddings that correspond to input order + if inputs == ["chunk 1", "chunk 2"]: + return [[0.1, 0.1], [0.2, 0.2]] + elif inputs == ["chunk 3", "chunk 4"]: + return [[0.3, 0.3], [0.4, 0.4]] + else: + return [[0.1, 0.1]] * len(inputs) + + embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings) + + chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"] + file_id = "test_file" + source_id = "test_source" + + passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + # verify order is preserved + assert len(passages) == 4 + assert passages[0].text == "chunk 1" + assert passages[0].embedding[:2] == [0.1, 0.1] # check first 2 values before padding + assert passages[1].text == "chunk 2" + assert passages[1].embedding[:2] == [0.2, 0.2] + assert passages[2].text == "chunk 3" + assert passages[2].embedding[:2] == [0.3, 0.3] + assert passages[3].text == "chunk 4" + assert passages[3].embedding[:2] == [0.4, 0.4] diff --git a/tests/test_managers.py b/tests/test_managers.py index 2048533e..f8833f8f 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -673,6 +673,7 @@ async def file_attachment(server, default_user, sarah_agent, default_file): agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, visible_content="initial", ) @@ -903,6 +904,7 @@ async def test_get_context_window_basic( agent_id=created_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, visible_content="hello", ) @@ -7221,6 +7223,7 @@ async def test_attach_creates_association(server, default_user, sarah_agent, def agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, visible_content="hello", ) @@ -7243,6 +7246,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, visible_content="first", ) @@ -7252,6 +7256,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, is_open=False, visible_content="second", @@ -7326,15 +7331,28 @@ async def test_list_files_and_agents( ): # default_file ↔ charles (open) await server.file_agent_manager.attach_file( - agent_id=charles_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user + agent_id=charles_agent.id, + file_id=default_file.id, + file_name=default_file.file_name, + source_id=default_file.source_id, + actor=default_user, ) # default_file ↔ sarah (open) await server.file_agent_manager.attach_file( - agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user + agent_id=sarah_agent.id, + file_id=default_file.id, + file_name=default_file.file_name, + source_id=default_file.source_id, + actor=default_user, ) # another_file ↔ sarah (closed) await server.file_agent_manager.attach_file( - agent_id=sarah_agent.id, file_id=another_file.id, file_name=another_file.file_name, actor=default_user, is_open=False + agent_id=sarah_agent.id, + file_id=another_file.id, + file_name=another_file.file_name, + source_id=another_file.source_id, + actor=default_user, + is_open=False, ) files_for_sarah = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user) @@ -7384,6 +7402,7 @@ async def test_org_scoping( agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, ) @@ -7420,6 +7439,7 @@ async def test_mark_access_bulk(server, default_user, sarah_agent, default_sourc agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, + source_id=file.source_id, actor=default_user, visible_content=f"content for {file.file_name}", ) @@ -7478,6 +7498,7 @@ async def test_lru_eviction_on_attach(server, default_user, sarah_agent, default agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, + source_id=file.source_id, actor=default_user, visible_content=f"content for {file.file_name}", ) @@ -7530,6 +7551,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa agent_id=sarah_agent.id, file_id=files[i].id, file_name=files[i].file_name, + source_id=files[i].source_id, actor=default_user, visible_content=f"content for {files[i].file_name}", ) @@ -7539,6 +7561,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, + source_id=files[-1].source_id, actor=default_user, is_open=False, visible_content=f"content for {files[-1].file_name}", @@ -7555,7 +7578,12 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa # Now "open" the last file using the efficient method closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open( - agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content" + agent_id=sarah_agent.id, + file_id=files[-1].id, + file_name=files[-1].file_name, + source_id=files[-1].source_id, + actor=default_user, + visible_content="updated content", ) # Should have closed 1 file (the oldest one) @@ -7603,6 +7631,7 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, + source_id=file.source_id, actor=default_user, visible_content=f"content for {file.file_name}", ) @@ -7617,7 +7646,12 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa # "Reopen" the last file (which is already open) closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open( - agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content" + agent_id=sarah_agent.id, + file_id=files[-1].id, + file_name=files[-1].file_name, + source_id=files[-1].source_id, + actor=default_user, + visible_content="updated content", ) # Should not have closed any files since we're within the limit @@ -7645,7 +7679,12 @@ async def test_last_accessed_at_updates_correctly(server, default_user, sarah_ag file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user, text="test content") file_agent, closed_files = await server.file_agent_manager.attach_file( - agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, actor=default_user, visible_content="initial content" + agent_id=sarah_agent.id, + file_id=file.id, + file_name=file.file_name, + source_id=file.source_id, + actor=default_user, + visible_content="initial content", ) initial_time = file_agent.last_accessed_at @@ -7777,6 +7816,7 @@ async def test_attach_files_bulk_lru_eviction(server, default_user, sarah_agent, agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, + source_id=file.source_id, actor=default_user, visible_content=f"existing content {i}", ) @@ -7842,6 +7882,7 @@ async def test_attach_files_bulk_mixed_existing_new(server, default_user, sarah_ agent_id=sarah_agent.id, file_id=existing_file.id, file_name=existing_file.file_name, + source_id=existing_file.source_id, actor=default_user, visible_content="old content", is_open=False, # Start as closed