From 33eaabb04a885e8da736219dff834e64c56a8bde Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 14 Jul 2025 11:03:15 -0700 Subject: [PATCH] chore: bump version 0.8.14 (#2720) Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com> Co-authored-by: Sarah Wooders Co-authored-by: Matthew Zhou Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com> Co-authored-by: jnjpng Co-authored-by: Jin Peng Co-authored-by: cpacker Co-authored-by: Shubham Naik Co-authored-by: Shubham Naik Co-authored-by: Kevin Lin --- ...rite_source_id_directly_to_files_agents.py | 52 +++ letta/__init__.py | 2 +- letta/constants.py | 6 + letta/functions/function_sets/base.py | 4 +- letta/helpers/pinecone_utils.py | 175 ++++++++- letta/orm/file.py | 19 +- letta/orm/files_agents.py | 19 +- letta/orm/organization.py | 4 - letta/orm/passage.py | 10 - letta/orm/source.py | 23 +- letta/schemas/file.py | 1 + letta/schemas/memory.py | 4 +- letta/server/rest_api/routers/v1/agents.py | 8 +- letta/server/rest_api/routers/v1/messages.py | 8 +- letta/server/rest_api/routers/v1/sources.py | 6 +- letta/server/server.py | 3 - letta/services/agent_manager.py | 343 ++++++++++-------- letta/services/block_manager.py | 36 +- .../context_window_calculator.py | 25 +- .../token_counter.py | 40 ++ .../file_processor/chunker/line_chunker.py | 17 + .../embedder/openai_embedder.py | 55 ++- letta/services/files_agents_manager.py | 14 +- letta/services/group_manager.py | 22 +- letta/services/source_manager.py | 22 +- .../tool_executor/core_tool_executor.py | 4 +- .../tool_executor/files_tool_executor.py | 7 +- poetry.lock | 6 +- pyproject.toml | 3 +- test_agent_serialization.json | 4 +- tests/helpers/utils.py | 2 +- tests/test_file_processor.py | 219 +++++++++++ tests/test_managers.py | 53 ++- tests/test_sdk_client.py | 33 ++ tests/test_utils.py | 5 +- 35 files changed, 944 insertions(+), 310 deletions(-) create mode 100644 alembic/versions/495f3f474131_write_source_id_directly_to_files_agents.py create mode 100644 tests/test_file_processor.py diff --git a/alembic/versions/495f3f474131_write_source_id_directly_to_files_agents.py b/alembic/versions/495f3f474131_write_source_id_directly_to_files_agents.py new file mode 100644 index 00000000..9319e99c --- /dev/null +++ b/alembic/versions/495f3f474131_write_source_id_directly_to_files_agents.py @@ -0,0 +1,52 @@ +"""Write source_id directly to files agents + +Revision ID: 495f3f474131 +Revises: 47d2277e530d +Create Date: 2025-07-10 17:14:45.154738 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "495f3f474131" +down_revision: Union[str, None] = "47d2277e530d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Step 1: Add the column as nullable first + op.add_column("files_agents", sa.Column("source_id", sa.String(), nullable=True)) + + # Step 2: Backfill source_id from files table + connection = op.get_bind() + connection.execute( + sa.text( + """ + UPDATE files_agents + SET source_id = files.source_id + FROM files + WHERE files_agents.file_id = files.id + """ + ) + ) + + # Step 3: Make the column NOT NULL now that it's populated + op.alter_column("files_agents", "source_id", nullable=False) + + # Step 4: Add the foreign key constraint + op.create_foreign_key(None, "files_agents", "sources", ["source_id"], ["id"], ondelete="CASCADE") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "files_agents", type_="foreignkey") + op.drop_column("files_agents", "source_id") + # ### end Alembic commands ### diff --git a/letta/__init__.py b/letta/__init__.py index 3aae91bf..83fc346e 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -5,7 +5,7 @@ try: __version__ = version("letta") except PackageNotFoundError: # Fallback for development installations - __version__ = "0.8.13" + __version__ = "0.8.14" if os.environ.get("LETTA_VERSION"): __version__ = os.environ["LETTA_VERSION"] diff --git a/letta/constants.py b/letta/constants.py index 5b7725bb..05711ab3 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -372,3 +372,9 @@ PINECONE_METRIC = "cosine" PINECONE_CLOUD = "aws" PINECONE_REGION = "us-east-1" PINECONE_MAX_BATCH_SIZE = 96 + +# retry configuration +PINECONE_MAX_RETRY_ATTEMPTS = 3 +PINECONE_RETRY_BASE_DELAY = 1.0 # seconds +PINECONE_RETRY_MAX_DELAY = 60.0 # seconds +PINECONE_RETRY_BACKOFF_FACTOR = 2.0 diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index f70ecb31..ad91fb1f 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -135,7 +135,7 @@ def core_memory_append(agent_state: "AgentState", label: str, content: str) -> O Append to the contents of core memory. Args: - label (str): Section of the memory to be edited (persona or human). + label (str): Section of the memory to be edited. content (str): Content to write to the memory. All unicode (including emojis) are supported. Returns: @@ -152,7 +152,7 @@ def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, Replace the contents of core memory. To delete memories, use an empty string for new_content. Args: - label (str): Section of the memory to be edited (persona or human). + label (str): Section of the memory to be edited. old_content (str): String to replace. Must be an exact match. new_content (str): Content to write to the memory. All unicode (including emojis) are supported. diff --git a/letta/helpers/pinecone_utils.py b/letta/helpers/pinecone_utils.py index f583b933..7caa5280 100644 --- a/letta/helpers/pinecone_utils.py +++ b/letta/helpers/pinecone_utils.py @@ -1,8 +1,20 @@ +import asyncio +import random +import time +from functools import wraps from typing import Any, Dict, List +from letta.otel.tracing import trace_method + try: from pinecone import IndexEmbed, PineconeAsyncio - from pinecone.exceptions.exceptions import NotFoundException + from pinecone.exceptions.exceptions import ( + ForbiddenException, + NotFoundException, + PineconeApiException, + ServiceException, + UnauthorizedException, + ) PINECONE_AVAILABLE = True except ImportError: @@ -12,8 +24,12 @@ from letta.constants import ( PINECONE_CLOUD, PINECONE_EMBEDDING_MODEL, PINECONE_MAX_BATCH_SIZE, + PINECONE_MAX_RETRY_ATTEMPTS, PINECONE_METRIC, PINECONE_REGION, + PINECONE_RETRY_BACKOFF_FACTOR, + PINECONE_RETRY_BASE_DELAY, + PINECONE_RETRY_MAX_DELAY, PINECONE_TEXT_FIELD_NAME, ) from letta.log import get_logger @@ -23,6 +39,87 @@ from letta.settings import settings logger = get_logger(__name__) +def pinecone_retry( + max_attempts: int = PINECONE_MAX_RETRY_ATTEMPTS, + base_delay: float = PINECONE_RETRY_BASE_DELAY, + max_delay: float = PINECONE_RETRY_MAX_DELAY, + backoff_factor: float = PINECONE_RETRY_BACKOFF_FACTOR, +): + """ + Decorator to retry Pinecone operations with exponential backoff. + + Args: + max_attempts: Maximum number of retry attempts + base_delay: Base delay in seconds for the first retry + max_delay: Maximum delay in seconds between retries + backoff_factor: Factor to increase delay after each failed attempt + """ + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + operation_name = func.__name__ + start_time = time.time() + + for attempt in range(max_attempts): + try: + logger.debug(f"[Pinecone] Starting {operation_name} (attempt {attempt + 1}/{max_attempts})") + result = await func(*args, **kwargs) + + execution_time = time.time() - start_time + logger.info(f"[Pinecone] {operation_name} completed successfully in {execution_time:.2f}s") + return result + + except (ServiceException, PineconeApiException) as e: + # retryable server errors + if attempt == max_attempts - 1: + execution_time = time.time() - start_time + logger.error(f"[Pinecone] {operation_name} failed after {max_attempts} attempts in {execution_time:.2f}s: {str(e)}") + raise + + # calculate delay with exponential backoff and jitter + delay = min(base_delay * (backoff_factor**attempt), max_delay) + jitter = random.uniform(0, delay * 0.1) # add up to 10% jitter + total_delay = delay + jitter + + logger.warning( + f"[Pinecone] {operation_name} failed (attempt {attempt + 1}/{max_attempts}): {str(e)}. Retrying in {total_delay:.2f}s" + ) + await asyncio.sleep(total_delay) + + except (UnauthorizedException, ForbiddenException) as e: + # non-retryable auth errors + execution_time = time.time() - start_time + logger.error(f"[Pinecone] {operation_name} failed with auth error in {execution_time:.2f}s: {str(e)}") + raise + + except NotFoundException as e: + # non-retryable not found errors + execution_time = time.time() - start_time + logger.warning(f"[Pinecone] {operation_name} failed with not found error in {execution_time:.2f}s: {str(e)}") + raise + + except Exception as e: + # other unexpected errors - retry once then fail + if attempt == max_attempts - 1: + execution_time = time.time() - start_time + logger.error(f"[Pinecone] {operation_name} failed after {max_attempts} attempts in {execution_time:.2f}s: {str(e)}") + raise + + delay = min(base_delay * (backoff_factor**attempt), max_delay) + jitter = random.uniform(0, delay * 0.1) + total_delay = delay + jitter + + logger.warning( + f"[Pinecone] {operation_name} failed with unexpected error (attempt {attempt + 1}/{max_attempts}): {str(e)}. Retrying in {total_delay:.2f}s" + ) + await asyncio.sleep(total_delay) + + return wrapper + + return decorator + + def should_use_pinecone(verbose: bool = False): if verbose: logger.info( @@ -44,29 +141,42 @@ def should_use_pinecone(verbose: bool = False): ) +@pinecone_retry() +@trace_method async def upsert_pinecone_indices(): if not PINECONE_AVAILABLE: raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") - for index_name in get_pinecone_indices(): + indices = get_pinecone_indices() + logger.info(f"[Pinecone] Upserting {len(indices)} indices: {indices}") + + for index_name in indices: async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: if not await pc.has_index(index_name): + logger.info(f"[Pinecone] Creating index {index_name} with model {PINECONE_EMBEDDING_MODEL}") await pc.create_index_for_model( name=index_name, cloud=PINECONE_CLOUD, region=PINECONE_REGION, embed=IndexEmbed(model=PINECONE_EMBEDDING_MODEL, field_map={"text": PINECONE_TEXT_FIELD_NAME}, metric=PINECONE_METRIC), ) + logger.info(f"[Pinecone] Successfully created index {index_name}") + else: + logger.debug(f"[Pinecone] Index {index_name} already exists") def get_pinecone_indices() -> List[str]: return [settings.pinecone_agent_index, settings.pinecone_source_index] +@pinecone_retry() +@trace_method async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, chunks: List[str], actor: User): if not PINECONE_AVAILABLE: raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") + logger.info(f"[Pinecone] Preparing to upsert {len(chunks)} chunks for file {file_id} source {source_id}") + records = [] for i, chunk in enumerate(chunks): record = { @@ -77,14 +187,19 @@ async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, ch } records.append(record) + logger.debug(f"[Pinecone] Created {len(records)} records for file {file_id}") return await upsert_records_to_pinecone_index(records, actor) +@pinecone_retry() +@trace_method async def delete_file_records_from_pinecone_index(file_id: str, actor: User): if not PINECONE_AVAILABLE: raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") namespace = actor.organization_id + logger.info(f"[Pinecone] Deleting records for file {file_id} from index {settings.pinecone_source_index} namespace {namespace}") + try: async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: description = await pc.describe_index(name=settings.pinecone_source_index) @@ -95,48 +210,72 @@ async def delete_file_records_from_pinecone_index(file_id: str, actor: User): }, namespace=namespace, ) + logger.info(f"[Pinecone] Successfully deleted records for file {file_id}") except NotFoundException: - logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}") + logger.warning(f"[Pinecone] Namespace {namespace} not found for file {file_id} and org {actor.organization_id}") +@pinecone_retry() +@trace_method async def delete_source_records_from_pinecone_index(source_id: str, actor: User): if not PINECONE_AVAILABLE: raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") namespace = actor.organization_id + logger.info(f"[Pinecone] Deleting records for source {source_id} from index {settings.pinecone_source_index} namespace {namespace}") + try: async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: description = await pc.describe_index(name=settings.pinecone_source_index) async with pc.IndexAsyncio(host=description.index.host) as dense_index: await dense_index.delete(filter={"source_id": {"$eq": source_id}}, namespace=namespace) + logger.info(f"[Pinecone] Successfully deleted records for source {source_id}") except NotFoundException: - logger.warning(f"Pinecone namespace {namespace} not found for {source_id} and {actor.organization_id}") + logger.warning(f"[Pinecone] Namespace {namespace} not found for source {source_id} and org {actor.organization_id}") +@pinecone_retry() +@trace_method async def upsert_records_to_pinecone_index(records: List[dict], actor: User): if not PINECONE_AVAILABLE: raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") + logger.info(f"[Pinecone] Upserting {len(records)} records to index {settings.pinecone_source_index} for org {actor.organization_id}") + async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: description = await pc.describe_index(name=settings.pinecone_source_index) async with pc.IndexAsyncio(host=description.index.host) as dense_index: - # Process records in batches to avoid exceeding Pinecone limits + # process records in batches to avoid exceeding pinecone limits + total_batches = (len(records) + PINECONE_MAX_BATCH_SIZE - 1) // PINECONE_MAX_BATCH_SIZE + logger.debug(f"[Pinecone] Processing {total_batches} batches of max {PINECONE_MAX_BATCH_SIZE} records each") + for i in range(0, len(records), PINECONE_MAX_BATCH_SIZE): batch = records[i : i + PINECONE_MAX_BATCH_SIZE] + batch_num = (i // PINECONE_MAX_BATCH_SIZE) + 1 + + logger.debug(f"[Pinecone] Upserting batch {batch_num}/{total_batches} with {len(batch)} records") await dense_index.upsert_records(actor.organization_id, batch) + logger.info(f"[Pinecone] Successfully upserted all {len(records)} records in {total_batches} batches") + +@pinecone_retry() +@trace_method async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], actor: User) -> Dict[str, Any]: if not PINECONE_AVAILABLE: raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") + namespace = actor.organization_id + logger.info( + f"[Pinecone] Searching index {settings.pinecone_source_index} namespace {namespace} with query length {len(query)} chars, limit {limit}" + ) + logger.debug(f"[Pinecone] Search filter: {filter}") + async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: description = await pc.describe_index(name=settings.pinecone_source_index) async with pc.IndexAsyncio(host=description.index.host) as dense_index: - namespace = actor.organization_id - try: - # Search the dense index with reranking + # search the dense index with reranking search_results = await dense_index.search( namespace=namespace, query={ @@ -146,17 +285,26 @@ async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], }, rerank={"model": "bge-reranker-v2-m3", "top_n": limit, "rank_fields": [PINECONE_TEXT_FIELD_NAME]}, ) + + result_count = len(search_results.get("matches", [])) + logger.info(f"[Pinecone] Search completed, found {result_count} matches") return search_results + except Exception as e: - logger.warning(f"Failed to search Pinecone namespace {actor.organization_id}: {str(e)}") + logger.warning(f"[Pinecone] Failed to search namespace {namespace}: {str(e)}") raise e +@pinecone_retry() +@trace_method async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]: if not PINECONE_AVAILABLE: raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") namespace = actor.organization_id + logger.info(f"[Pinecone] Listing records for file {file_id} from index {settings.pinecone_source_index} namespace {namespace}") + logger.debug(f"[Pinecone] List params - limit: {limit}, pagination_token: {pagination_token}") + try: async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: description = await pc.describe_index(name=settings.pinecone_source_index) @@ -172,9 +320,14 @@ async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = result = [] async for ids in dense_index.list(**kwargs): result.extend(ids) + + logger.info(f"[Pinecone] Successfully listed {len(result)} records for file {file_id}") return result + except Exception as e: - logger.warning(f"Failed to list Pinecone namespace {actor.organization_id}: {str(e)}") + logger.warning(f"[Pinecone] Failed to list records for file {file_id} in namespace {namespace}: {str(e)}") raise e + except NotFoundException: - logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}") + logger.warning(f"[Pinecone] Namespace {namespace} not found for file {file_id} and org {actor.organization_id}") + return [] diff --git a/letta/orm/file.py b/letta/orm/file.py index 885731e5..8cae2448 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -1,5 +1,5 @@ import uuid -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from sqlalchemy import ForeignKey, Index, Integer, String, Text, UniqueConstraint, desc from sqlalchemy.ext.asyncio import AsyncAttrs @@ -11,10 +11,7 @@ from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata as PydanticFileMetadata if TYPE_CHECKING: - from letta.orm.files_agents import FileAgent - from letta.orm.organization import Organization - from letta.orm.passage import SourcePassage - from letta.orm.source import Source + pass # TODO: Note that this is NOT organization scoped, this is potentially dangerous if we misuse this @@ -64,18 +61,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): chunks_embedded: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Number of chunks that have been embedded.") # relationships - organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin") - source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin") - source_passages: Mapped[List["SourcePassage"]] = relationship( - "SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan" - ) - file_agents: Mapped[List["FileAgent"]] = relationship( - "FileAgent", - back_populates="file", - lazy="selectin", - cascade="all, delete-orphan", - passive_deletes=True, # ← add this - ) content: Mapped[Optional["FileContent"]] = relationship( "FileContent", uselist=False, diff --git a/letta/orm/files_agents.py b/letta/orm/files_agents.py index f7398a91..d8fd5c2f 100644 --- a/letta/orm/files_agents.py +++ b/letta/orm/files_agents.py @@ -12,7 +12,7 @@ from letta.schemas.block import Block as PydanticBlock from letta.schemas.file import FileAgent as PydanticFileAgent if TYPE_CHECKING: - from letta.orm.file import FileMetadata + pass class FileAgent(SqlalchemyBase, OrganizationMixin): @@ -55,6 +55,12 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): nullable=False, doc="ID of the agent", ) + source_id: Mapped[str] = mapped_column( + String, + ForeignKey("sources.id", ondelete="CASCADE"), + nullable=False, + doc="ID of the source (denormalized from files.source_id)", + ) file_name: Mapped[str] = mapped_column( String, @@ -78,13 +84,6 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): back_populates="file_agents", lazy="selectin", ) - file: Mapped["FileMetadata"] = relationship( - "FileMetadata", - foreign_keys=[file_id], - lazy="selectin", - back_populates="file_agents", - passive_deletes=True, # ← add this - ) # TODO: This is temporary as we figure out if we want FileBlock as a first class citizen def to_pydantic_block(self) -> PydanticBlock: @@ -99,8 +98,8 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): return PydanticBlock( organization_id=self.organization_id, value=visible_content, - label=self.file.file_name, + label=self.file_name, # use denormalized file_name instead of self.file.file_name read_only=True, - metadata={"source_id": self.file.source_id}, + metadata={"source_id": self.source_id}, # use denormalized source_id limit=CORE_MEMORY_SOURCE_CHAR_LIMIT, ) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index e1937633..f5f65cb9 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: from letta.orm.agent import Agent from letta.orm.agent_passage import AgentPassage from letta.orm.block import Block - from letta.orm.file import FileMetadata from letta.orm.group import Group from letta.orm.identity import Identity from letta.orm.llm_batch_item import LLMBatchItem @@ -18,7 +17,6 @@ if TYPE_CHECKING: from letta.orm.provider import Provider from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig from letta.orm.sandbox_environment_variable import SandboxEnvironmentVariable - from letta.orm.source import Source from letta.orm.source_passage import SourcePassage from letta.orm.tool import Tool from letta.orm.user import User @@ -38,8 +36,6 @@ class Organization(SqlalchemyBase): tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") # mcp_servers: Mapped[List["MCPServer"]] = relationship("MCPServer", back_populates="organization", cascade="all, delete-orphan") blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan") - sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") - files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan") sandbox_configs: Mapped[List["SandboxConfig"]] = relationship( "SandboxConfig", back_populates="organization", cascade="all, delete-orphan" ) diff --git a/letta/orm/passage.py b/letta/orm/passage.py index 82451027..868f8a67 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -49,11 +49,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin): file_name: Mapped[str] = mapped_column(doc="The name of the file that this passage was derived from") - @declared_attr - def file(cls) -> Mapped["FileMetadata"]: - """Relationship to file""" - return relationship("FileMetadata", back_populates="source_passages", lazy="selectin") - @declared_attr def organization(cls) -> Mapped["Organization"]: return relationship("Organization", back_populates="source_passages", lazy="selectin") @@ -74,11 +69,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin): {"extend_existing": True}, ) - @declared_attr - def source(cls) -> Mapped["Source"]: - """Relationship to source""" - return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True) - class AgentPassage(BasePassage, AgentMixin): """Passages created by agents as archival memories""" diff --git a/letta/orm/source.py b/letta/orm/source.py index c4a0f2d9..f23c61e5 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -1,9 +1,8 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from sqlalchemy import JSON, Index, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column -from letta.orm import FileMetadata from letta.orm.custom_columns import EmbeddingConfigColumn from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase @@ -11,10 +10,7 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.source import Source as PydanticSource if TYPE_CHECKING: - from letta.orm.agent import Agent - from letta.orm.file import FileMetadata - from letta.orm.organization import Organization - from letta.orm.passage import SourcePassage + pass class Source(SqlalchemyBase, OrganizationMixin): @@ -34,16 +30,3 @@ class Source(SqlalchemyBase, OrganizationMixin): instructions: Mapped[str] = mapped_column(nullable=True, doc="instructions for how to use the source") embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.") metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.") - - # relationships - organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") - files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan") - passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan") - agents: Mapped[List["Agent"]] = relationship( - "Agent", - secondary="sources_agents", - back_populates="sources", - lazy="selectin", - cascade="save-update", # Only propagate save and update operations - passive_deletes=True, # Let the database handle deletions - ) diff --git a/letta/schemas/file.py b/letta/schemas/file.py index 14e2a122..90132c50 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -85,6 +85,7 @@ class FileAgent(FileAgentBase): ) agent_id: str = Field(..., description="Unique identifier of the agent.") file_id: str = Field(..., description="Unique identifier of the file.") + source_id: str = Field(..., description="Unique identifier of the source (denormalized from files.source_id).") file_name: str = Field(..., description="Name of the file.") is_open: bool = Field(True, description="True if the agent currently has the file open.") visible_content: Optional[str] = Field( diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 97658393..eac33ae7 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -210,7 +210,7 @@ class BasicBlockMemory(Memory): Append to the contents of core memory. Args: - label (str): Section of the memory to be edited (persona or human). + label (str): Section of the memory to be edited. content (str): Content to write to the memory. All unicode (including emojis) are supported. Returns: @@ -226,7 +226,7 @@ class BasicBlockMemory(Memory): Replace the contents of core memory. To delete memories, use an empty string for new_content. Args: - label (str): Section of the memory to be edited (persona or human). + label (str): Section of the memory to be edited. old_content (str): String to replace. Must be an exact match. new_content (str): Content to write to the memory. All unicode (including emojis) are supported. diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index fee2de03..78355f66 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -272,14 +272,14 @@ async def modify_agent( @router.get("/{agent_id}/tools", response_model=list[Tool], operation_id="list_agent_tools") -def list_agent_tools( +async def list_agent_tools( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Get tools from an existing agent""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.list_attached_tools_async(agent_id=agent_id, actor=actor) @router.patch("/{agent_id}/tools/attach/{tool_id}", response_model=AgentState, operation_id="attach_tool") @@ -1072,7 +1072,7 @@ async def _process_message_background( completed_at=datetime.now(timezone.utc), metadata={"error": str(e)}, ) - await server.job_manager.update_job_by_id_async(job_id=job_id, job_update=job_update, actor=actor) + await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor) @router.post( diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index e156d05d..28fcd185 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -1,6 +1,6 @@ from typing import List, Optional -from fastapi import APIRouter, Body, Depends, Header, Query, status +from fastapi import APIRouter, Body, Depends, Header, Query from fastapi.exceptions import HTTPException from starlette.requests import Request @@ -45,12 +45,8 @@ async def create_messages_batch( if length > max_bytes: raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.") - # Reject request if env var is not set if not settings.enable_batch_job_polling: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.", - ) + logger.warning("Batch job polling is disabled. Enable batch processing by setting LETTA_ENABLE_BATCH_JOB_POLLING to True.") actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) batch_job = BatchJob( diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index e7ab5370..3997d578 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -391,8 +391,8 @@ async def get_file_metadata( if file_metadata.source_id != source_id: raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.") - if should_use_pinecone() and not file_metadata.is_processing_terminal(): - ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor, limit=file_metadata.total_chunks) + if should_use_pinecone() and file_metadata.processing_status == FileProcessingStatus.EMBEDDING: + ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor) logger.info( f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_id} ({file_metadata.file_name}) in organization {actor.organization_id}" ) @@ -402,7 +402,7 @@ async def get_file_metadata( file_status = file_metadata.processing_status else: file_status = FileProcessingStatus.COMPLETED - await server.file_manager.update_file_status( + file_metadata = await server.file_manager.update_file_status( file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status ) diff --git a/letta/server/server.py b/letta/server/server.py index 8a2c38fe..af996ffc 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1342,9 +1342,6 @@ class SyncServer(Server): new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id) assert new_passage_size >= curr_passage_size # in case empty files are added - # rebuild system prompt and force - agent_state = await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True) - # update job status job.status = JobStatus.completed job.metadata["num_passages"] = num_passages diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 2bc79280..d9253a9a 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -107,6 +107,32 @@ class AgentManager: self.identity_manager = IdentityManager() self.file_agent_manager = FileAgentManager() + @trace_method + async def _validate_agent_exists_async(self, session, agent_id: str, actor: PydanticUser) -> None: + """ + Validate that an agent exists and user has access to it using raw SQL for efficiency. + + Args: + session: Database session + agent_id: ID of the agent to validate + actor: User performing the action + + Raises: + NoResultFound: If agent doesn't exist or user doesn't have access + """ + agent_check_query = sa.text( + """ + SELECT 1 FROM agents + WHERE id = :agent_id + AND organization_id = :org_id + AND is_deleted = false + """ + ) + agent_exists = await session.execute(agent_check_query, {"agent_id": agent_id, "org_id": actor.organization_id}) + + if not agent_exists.fetchone(): + raise NoResultFound(f"Agent with ID {agent_id} not found") + @staticmethod def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]: """ @@ -635,24 +661,24 @@ class AgentManager: return init_messages - @trace_method @enforce_types + @trace_method def append_initial_message_sequence_to_in_context_messages( self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None ) -> PydanticAgentState: init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence) return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor) - @trace_method @enforce_types + @trace_method async def append_initial_message_sequence_to_in_context_messages_async( self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None ) -> PydanticAgentState: init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence) return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor) - @trace_method @enforce_types + @trace_method def update_agent( self, agent_id: str, @@ -773,8 +799,8 @@ class AgentManager: return agent.to_pydantic() - @trace_method @enforce_types + @trace_method async def update_agent_async( self, agent_id: str, @@ -1125,16 +1151,16 @@ class AgentManager: async with db_registry.async_session() as session: return await AgentModel.size_async(db_session=session, actor=actor) - @trace_method @enforce_types + @trace_method def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) return agent.to_pydantic() - @trace_method @enforce_types + @trace_method async def get_agent_by_id_async( self, agent_id: str, @@ -1147,8 +1173,8 @@ class AgentManager: agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) return await agent.to_pydantic_async(include_relationships=include_relationships) - @trace_method @enforce_types + @trace_method async def get_agents_by_ids_async( self, agent_ids: list[str], @@ -1164,16 +1190,16 @@ class AgentManager: ) return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) - @trace_method @enforce_types + @trace_method def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" with db_registry.session() as session: agent = AgentModel.read(db_session=session, name=agent_name, actor=actor) return agent.to_pydantic() - @trace_method @enforce_types + @trace_method def delete_agent(self, agent_id: str, actor: PydanticUser) -> None: """ Deletes an agent and its associated relationships. @@ -1220,8 +1246,8 @@ class AgentManager: else: logger.debug(f"Agent with ID {agent_id} successfully hard deleted") - @trace_method @enforce_types + @trace_method async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None: """ Deletes an agent and its associated relationships. @@ -1270,8 +1296,8 @@ class AgentManager: else: logger.debug(f"Agent with ID {agent_id} successfully hard deleted") - @trace_method @enforce_types + @trace_method def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema: with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -1279,8 +1305,8 @@ class AgentManager: data = schema.dump(agent) return AgentSchema(**data) - @trace_method @enforce_types + @trace_method def deserialize( self, serialized_agent: AgentSchema, @@ -1349,8 +1375,8 @@ class AgentManager: # ====================================================================================================================== # Per Agent Environment Variable Management # ====================================================================================================================== - @trace_method @enforce_types + @trace_method def _set_environment_variables( self, agent_id: str, @@ -1405,8 +1431,8 @@ class AgentManager: # Return the updated agent state return agent.to_pydantic() - @trace_method @enforce_types + @trace_method def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]: with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -1422,20 +1448,20 @@ class AgentManager: # TODO: 2) These messages are ordered from oldest to newest # TODO: This can be fixed by having an actual relationship in the ORM for message_ids # TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query. - @trace_method @enforce_types + @trace_method def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor) - @trace_method @enforce_types + @trace_method def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor) - @trace_method @enforce_types + @trace_method async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor) @@ -1443,8 +1469,8 @@ class AgentManager: # TODO: This is duplicated below # TODO: This is legacy code and should be cleaned up # TODO: A lot of the memory "compilation" should be offset to a separate class - @trace_method @enforce_types + @trace_method def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState: """Rebuilds the system message with the latest memory object and any shared memory block updates @@ -1515,29 +1541,42 @@ class AgentManager: else: return agent_state - @trace_method + # TODO: This is probably one of the worst pieces of code I've ever written please rip up as you see wish @enforce_types + @trace_method async def rebuild_system_prompt_async( - self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True, tool_rules_solver: Optional[ToolRulesSolver] = None - ) -> PydanticAgentState: + self, + agent_id: str, + actor: PydanticUser, + force=False, + update_timestamp=True, + tool_rules_solver: Optional[ToolRulesSolver] = None, + dry_run: bool = False, + ) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]: """Rebuilds the system message with the latest memory object and any shared memory block updates Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages """ - # Get the current agent state - agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources"], actor=actor) + num_messages_task = self.message_manager.size_async(actor=actor, agent_id=agent_id) + num_archival_memories_task = self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id) + agent_state_task = self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor) + + num_messages, num_archival_memories, agent_state = await asyncio.gather( + num_messages_task, + num_archival_memories_task, + agent_state_task, + ) + if not tool_rules_solver: tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) - curr_system_message = await self.get_system_message_async( - agent_id=agent_id, actor=actor - ) # this is the system + memory bank, not just the system prompt + curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor) if curr_system_message is None: logger.warning(f"No system message found for agent {agent_state.id} and user {actor}") - return agent_state + return agent_state, curr_system_message, num_messages, num_archival_memories curr_system_message_openai = curr_system_message.to_openai_dict() @@ -1551,7 +1590,7 @@ class AgentManager: logger.debug( f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild" ) - return agent_state + return agent_state, curr_system_message, num_messages, num_archival_memories # If the memory didn't update, we probably don't want to update the timestamp inside # For example, if we're doing a system prompt swap, this should probably be False @@ -1561,9 +1600,6 @@ class AgentManager: # NOTE: a bit of a hack - we pull the timestamp from the message created_by memory_edit_timestamp = curr_system_message.created_at - num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id) - num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id) - # update memory (TODO: potentially update recall/archival stats separately) new_system_message_str = compile_system_message( @@ -1582,63 +1618,67 @@ class AgentManager: logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") # Swap the system message out (only if there is a diff) - message = PydanticMessage.dict_to_message( + temp_message = PydanticMessage.dict_to_message( agent_id=agent_id, model=agent_state.llm_config.model, openai_message_dict={"role": "system", "content": new_system_message_str}, ) - message = await self.message_manager.update_message_by_id_async( - message_id=curr_system_message.id, - message_update=MessageUpdate(**message.model_dump()), - actor=actor, - ) - return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor) - else: - return agent_state + temp_message.id = curr_system_message.id + + if not dry_run: + await self.message_manager.update_message_by_id_async( + message_id=curr_system_message.id, + message_update=MessageUpdate(**temp_message.model_dump()), + actor=actor, + ) + else: + curr_system_message = temp_message + + return agent_state, curr_system_message, num_messages, num_archival_memories - @trace_method @enforce_types + @trace_method def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) - @trace_method @enforce_types + @trace_method async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) - @trace_method @enforce_types + @trace_method def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) - @trace_method @enforce_types + @trace_method def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids # TODO: How do we know this? new_messages = [message_ids[0]] # 0 is system message return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) - @trace_method @enforce_types + @trace_method def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids new_messages = self.message_manager.create_many_messages(messages, actor=actor) message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:] return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) - @trace_method @enforce_types + @trace_method def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: messages = self.message_manager.create_many_messages(messages, actor=actor) message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or [] message_ids += [m.id for m in messages] return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) - @trace_method @enforce_types + @trace_method async def append_to_in_context_messages_async( self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser ) -> PydanticAgentState: @@ -1648,8 +1688,8 @@ class AgentManager: message_ids += [m.id for m in messages] return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor) - @trace_method @enforce_types + @trace_method async def reset_messages_async( self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False ) -> PydanticAgentState: @@ -1712,8 +1752,8 @@ class AgentManager: else: return agent_state - @trace_method @enforce_types + @trace_method async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState: """ Update internal memory object and system prompt if there have been modifications. @@ -1756,12 +1796,12 @@ class AgentManager: # NOTE: don't do this since re-buildin the memory is handled at the start of the step # rebuild memory - this records the last edited timestamp of the memory # TODO: pass in update timestamp from block edit time - agent_state = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor) + await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor) return agent_state - @trace_method @enforce_types + @trace_method async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: # TODO: This will NOT work for new blocks/file blocks added intra-step block_ids = [b.id for b in agent_state.memory.blocks] @@ -1779,8 +1819,8 @@ class AgentManager: return agent_state - @trace_method @enforce_types + @trace_method async def refresh_file_blocks(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: file_blocks = await self.file_agent_manager.list_files_for_agent(agent_id=agent_state.id, actor=actor, return_as_blocks=True) agent_state.memory.file_blocks = [b for b in file_blocks if b is not None] @@ -1789,8 +1829,8 @@ class AgentManager: # ====================================================================================================================== # Source Management # ====================================================================================================================== - @trace_method @enforce_types + @trace_method async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: """ Attaches a source to an agent. @@ -1820,15 +1860,11 @@ class AgentManager: ) # Commit the changes - await agent.update_async(session, actor=actor) + agent = await agent.update_async(session, actor=actor) + return await agent.to_pydantic_async() - # Force rebuild of system prompt so that the agent is updated with passage count - pydantic_agent = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True) - - return pydantic_agent - - @trace_method @enforce_types + @trace_method def append_system_message(self, agent_id: str, content: str, actor: PydanticUser): # get the agent @@ -1840,8 +1876,8 @@ class AgentManager: # update agent in-context message IDs self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor) - @trace_method @enforce_types + @trace_method async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser): # get the agent @@ -1853,28 +1889,8 @@ class AgentManager: # update agent in-context message IDs await self.append_to_in_context_messages_async(messages=[message], agent_id=agent_id, actor=actor) - @trace_method @enforce_types - def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: - """ - Lists all sources attached to an agent. - - Args: - agent_id: ID of the agent to list sources for - actor: User performing the action - - Returns: - List[str]: List of source IDs attached to the agent - """ - with db_registry.session() as session: - # Verify agent exists and user has permission to access it - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # Use the lazy-loaded relationship to get sources - return [source.to_pydantic() for source in agent.sources] - @trace_method - @enforce_types async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: """ Lists all sources attached to an agent. @@ -1885,44 +1901,34 @@ class AgentManager: Returns: List[str]: List of source IDs attached to the agent + + Raises: + NoResultFound: If agent doesn't exist or user doesn't have access """ async with db_registry.async_session() as session: - # Verify agent exists and user has permission to access it - agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + # Validate agent exists and user has access + await self._validate_agent_exists_async(session, agent_id, actor) - # Use the lazy-loaded relationship to get sources - return [source.to_pydantic() for source in agent.sources] + # Use raw SQL to efficiently fetch sources - much faster than lazy loading + # Fast query without relationship loading + query = ( + select(SourceModel) + .join(SourcesAgents, SourceModel.id == SourcesAgents.source_id) + .where( + SourcesAgents.agent_id == agent_id, + SourceModel.organization_id == actor.organization_id, + SourceModel.is_deleted == False, + ) + .order_by(SourceModel.created_at.desc(), SourceModel.id) + ) + + result = await session.execute(query) + sources = result.scalars().all() + + return [source.to_pydantic() for source in sources] - @trace_method @enforce_types - def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: - """ - Detaches a source from an agent. - - Args: - agent_id: ID of the agent to detach the source from - source_id: ID of the source to detach - actor: User performing the action - """ - with db_registry.session() as session: - # Verify agent exists and user has permission to access it - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # Remove the source from the relationship - remaining_sources = [s for s in agent.sources if s.id != source_id] - - if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship - logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}") - - # Update the sources relationship - agent.sources = remaining_sources - - # Commit the changes - agent.update(session, actor=actor) - return agent.to_pydantic() - @trace_method - @enforce_types async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: """ Detaches a source from an agent. @@ -1931,29 +1937,36 @@ class AgentManager: agent_id: ID of the agent to detach the source from source_id: ID of the source to detach actor: User performing the action + + Raises: + NoResultFound: If agent doesn't exist or user doesn't have access """ async with db_registry.async_session() as session: - # Verify agent exists and user has permission to access it - agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + # Validate agent exists and user has access + await self._validate_agent_exists_async(session, agent_id, actor) - # Remove the source from the relationship - remaining_sources = [s for s in agent.sources if s.id != source_id] + # Check if the source is actually attached to this agent using junction table + attachment_check_query = select(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id) + attachment_result = await session.execute(attachment_check_query) + attachment = attachment_result.scalar_one_or_none() - if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship + if not attachment: logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}") + else: + # Delete the association directly from the junction table + delete_query = delete(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id) + await session.execute(delete_query) + await session.commit() - # Update the sources relationship - agent.sources = remaining_sources - - # Commit the changes - await agent.update_async(session, actor=actor) + # Get agent without loading relationships for return value + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) return await agent.to_pydantic_async() # ====================================================================================================================== # Block management # ====================================================================================================================== - @trace_method @enforce_types + @trace_method def get_block_with_label( self, agent_id: str, @@ -1968,8 +1981,8 @@ class AgentManager: return block.to_pydantic() raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") - @trace_method @enforce_types + @trace_method async def get_block_with_label_async( self, agent_id: str, @@ -1984,8 +1997,8 @@ class AgentManager: return block.to_pydantic() raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") - @trace_method @enforce_types + @trace_method async def modify_block_by_label_async( self, agent_id: str, @@ -2012,8 +2025,8 @@ class AgentManager: await block.update_async(session, actor=actor) return block.to_pydantic() - @trace_method @enforce_types + @trace_method def update_block_with_label( self, agent_id: str, @@ -2037,8 +2050,8 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() - @trace_method @enforce_types + @trace_method def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: """Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group.""" with db_registry.session() as session: @@ -2067,8 +2080,8 @@ class AgentManager: return agent.to_pydantic() - @trace_method @enforce_types + @trace_method async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: """Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group.""" async with db_registry.async_session() as session: @@ -2103,8 +2116,8 @@ class AgentManager: return await agent.to_pydantic_async() - @trace_method @enforce_types + @trace_method def detach_block( self, agent_id: str, @@ -2124,8 +2137,8 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() - @trace_method @enforce_types + @trace_method async def detach_block_async( self, agent_id: str, @@ -2145,8 +2158,8 @@ class AgentManager: await agent.update_async(session, actor=actor) return await agent.to_pydantic_async() - @trace_method @enforce_types + @trace_method def detach_block_with_label( self, agent_id: str, @@ -2170,8 +2183,8 @@ class AgentManager: # Passage Management # ====================================================================================================================== - @trace_method @enforce_types + @trace_method def list_passages( self, actor: PydanticUser, @@ -2231,8 +2244,8 @@ class AgentManager: return [p.to_pydantic() for p in passages] - @trace_method @enforce_types + @trace_method async def list_passages_async( self, actor: PydanticUser, @@ -2292,8 +2305,8 @@ class AgentManager: return [p.to_pydantic() for p in passages] - @trace_method @enforce_types + @trace_method async def list_source_passages_async( self, actor: PydanticUser, @@ -2340,8 +2353,8 @@ class AgentManager: # Convert to Pydantic models return [p.to_pydantic() for p in passages] - @trace_method @enforce_types + @trace_method async def list_agent_passages_async( self, actor: PydanticUser, @@ -2384,8 +2397,8 @@ class AgentManager: # Convert to Pydantic models return [p.to_pydantic() for p in passages] - @trace_method @enforce_types + @trace_method def passage_size( self, actor: PydanticUser, @@ -2465,8 +2478,8 @@ class AgentManager: # ====================================================================================================================== # Tool Management # ====================================================================================================================== - @trace_method @enforce_types + @trace_method def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: """ Attaches a tool to an agent. @@ -2501,8 +2514,8 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() - @trace_method @enforce_types + @trace_method async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: """ Attaches a tool to an agent. @@ -2537,8 +2550,8 @@ class AgentManager: await agent.update_async(session, actor=actor) return await agent.to_pydantic_async() - @trace_method @enforce_types + @trace_method async def attach_missing_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: """ Attaches missing core file tools to an agent. @@ -2569,8 +2582,8 @@ class AgentManager: return agent_state - @trace_method @enforce_types + @trace_method async def detach_all_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: """ Detach all core file tools from an agent. @@ -2596,8 +2609,8 @@ class AgentManager: return agent_state - @trace_method @enforce_types + @trace_method def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: """ Detaches a tool from an agent. @@ -2630,8 +2643,8 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() - @trace_method @enforce_types + @trace_method async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: """ Detaches a tool from an agent. @@ -2664,8 +2677,8 @@ class AgentManager: await agent.update_async(session, actor=actor) return await agent.to_pydantic_async() - @trace_method @enforce_types + @trace_method def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]: """ List all tools attached to an agent. @@ -2681,11 +2694,40 @@ class AgentManager: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) return [tool.to_pydantic() for tool in agent.tools] + @enforce_types + @trace_method + async def list_attached_tools_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]: + """ + List all tools attached to an agent (async version with optimized performance). + Uses direct SQL queries to avoid SqlAlchemyBase overhead. + + Args: + agent_id: ID of the agent to list tools for. + actor: User performing the action. + + Returns: + List[PydanticTool]: List of tools attached to the agent. + """ + async with db_registry.async_session() as session: + # lightweight check for agent access + await self._validate_agent_exists_async(session, agent_id, actor) + + # direct query for tools via join - much more performant + query = ( + select(ToolModel) + .join(ToolsAgents, ToolModel.id == ToolsAgents.tool_id) + .where(ToolsAgents.agent_id == agent_id, ToolModel.organization_id == actor.organization_id) + ) + + result = await session.execute(query) + tools = result.scalars().all() + return [tool.to_pydantic() for tool in tools] + # ====================================================================================================================== # Tag Management # ====================================================================================================================== - @trace_method @enforce_types + @trace_method def list_tags( self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None ) -> List[str]: @@ -2719,8 +2761,8 @@ class AgentManager: results = [tag[0] for tag in query.all()] return results - @trace_method @enforce_types + @trace_method async def list_tags_async( self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None ) -> List[str]: @@ -2759,8 +2801,11 @@ class AgentManager: results = [row[0] for row in result.all()] return results + @trace_method async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: - agent_state = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True) + agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async( + agent_id=agent_id, actor=actor, force=True, dry_run=True + ) calculator = ContextWindowCalculator() if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION" and agent_state.llm_config.model_endpoint_type == "anthropic": @@ -2776,5 +2821,7 @@ class AgentManager: actor=actor, token_counter=token_counter, message_manager=self.message_manager, - passage_manager=self.passage_manager, + system_message_compiled=system_message, + num_archival_memories=num_archival_memories, + num_messages=num_messages, ) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 6aa89144..0c0203bd 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -23,8 +23,8 @@ logger = get_logger(__name__) class BlockManager: """Manager class to handle business logic related to Blocks.""" - @trace_method @enforce_types + @trace_method def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: """Create a new block based on the Block schema.""" db_block = self.get_block_by_id(block.id, actor) @@ -38,8 +38,8 @@ class BlockManager: block.create(session, actor=actor) return block.to_pydantic() - @trace_method @enforce_types + @trace_method async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: """Create a new block based on the Block schema.""" db_block = await self.get_block_by_id_async(block.id, actor) @@ -53,8 +53,8 @@ class BlockManager: await block.create_async(session, actor=actor) return block.to_pydantic() - @trace_method @enforce_types + @trace_method def batch_create_blocks(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]: """ Batch-create multiple Blocks in one transaction for better performance. @@ -77,8 +77,8 @@ class BlockManager: # Convert back to Pydantic return [m.to_pydantic() for m in created_models] - @trace_method @enforce_types + @trace_method async def batch_create_blocks_async(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]: """ Batch-create multiple Blocks in one transaction for better performance. @@ -101,8 +101,8 @@ class BlockManager: # Convert back to Pydantic return [m.to_pydantic() for m in created_models] - @trace_method @enforce_types + @trace_method def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" # Safety check for block @@ -117,8 +117,8 @@ class BlockManager: block.update(db_session=session, actor=actor) return block.to_pydantic() - @trace_method @enforce_types + @trace_method async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" # Safety check for block @@ -133,8 +133,8 @@ class BlockManager: await block.update_async(db_session=session, actor=actor) return block.to_pydantic() - @trace_method @enforce_types + @trace_method def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock: """Delete a block by its ID.""" with db_registry.session() as session: @@ -142,8 +142,8 @@ class BlockManager: block.hard_delete(db_session=session, actor=actor) return block.to_pydantic() - @trace_method @enforce_types + @trace_method async def delete_block_async(self, block_id: str, actor: PydanticUser) -> PydanticBlock: """Delete a block by its ID.""" async with db_registry.async_session() as session: @@ -151,8 +151,8 @@ class BlockManager: await block.hard_delete_async(db_session=session, actor=actor) return block.to_pydantic() - @trace_method @enforce_types + @trace_method async def get_blocks_async( self, actor: PydanticUser, @@ -214,8 +214,8 @@ class BlockManager: return [block.to_pydantic() for block in blocks] - @trace_method @enforce_types + @trace_method def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: """Retrieve a block by its name.""" with db_registry.session() as session: @@ -225,8 +225,8 @@ class BlockManager: except NoResultFound: return None - @trace_method @enforce_types + @trace_method async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: """Retrieve a block by its name.""" async with db_registry.async_session() as session: @@ -236,8 +236,8 @@ class BlockManager: except NoResultFound: return None - @trace_method @enforce_types + @trace_method async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: """Retrieve blocks by their ids without loading unnecessary relationships. Async implementation.""" from sqlalchemy import select @@ -284,8 +284,8 @@ class BlockManager: return pydantic_blocks - @trace_method @enforce_types + @trace_method async def get_agents_for_block_async( self, block_id: str, @@ -301,8 +301,8 @@ class BlockManager: agents = await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents_orm]) return agents - @trace_method @enforce_types + @trace_method async def size_async(self, actor: PydanticUser) -> int: """ Get the total count of blocks for the given user. @@ -312,8 +312,8 @@ class BlockManager: # Block History Functions - @trace_method @enforce_types + @trace_method def checkpoint_block( self, block_id: str, @@ -416,8 +416,8 @@ class BlockManager: updated_block = block.update(db_session=session, actor=actor, no_commit=True) return updated_block - @trace_method @enforce_types + @trace_method def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: """ Move the block to the immediately previous checkpoint in BlockHistory. @@ -459,8 +459,8 @@ class BlockManager: session.commit() return block.to_pydantic() - @trace_method @enforce_types + @trace_method def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: """ Move the block to the next checkpoint if it exists. @@ -498,8 +498,8 @@ class BlockManager: session.commit() return block.to_pydantic() - @trace_method @enforce_types + @trace_method async def bulk_update_block_values_async( self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False ) -> Optional[List[PydanticBlock]]: diff --git a/letta/services/context_window_calculator/context_window_calculator.py b/letta/services/context_window_calculator/context_window_calculator.py index 47a9aacd..c405d289 100644 --- a/letta/services/context_window_calculator/context_window_calculator.py +++ b/letta/services/context_window_calculator/context_window_calculator.py @@ -4,11 +4,14 @@ from typing import Any, List, Optional, Tuple from openai.types.beta.function_tool import FunctionTool as OpenAITool from letta.log import get_logger +from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent from letta.schemas.memory import ContextWindowOverview +from letta.schemas.message import Message from letta.schemas.user import User as PydanticUser from letta.services.context_window_calculator.token_counter import TokenCounter +from letta.services.message_manager import MessageManager logger = get_logger(__name__) @@ -56,16 +59,18 @@ class ContextWindowCalculator: return None, 1 async def calculate_context_window( - self, agent_state: Any, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any + self, + agent_state: AgentState, + actor: PydanticUser, + token_counter: TokenCounter, + message_manager: MessageManager, + system_message_compiled: Message, + num_archival_memories: int, + num_messages: int, ) -> ContextWindowOverview: """Calculate context window information using the provided token counter""" - - # Fetch data concurrently - (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor), - passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id), - message_manager.size_async(actor=actor, agent_id=agent_state.id), - ) + messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids[1:], actor=actor) + in_context_messages = [system_message_compiled] + messages # Convert messages to appropriate format converted_messages = token_counter.convert_messages(in_context_messages) @@ -128,8 +133,8 @@ class ContextWindowCalculator: return ContextWindowOverview( # context window breakdown (in messages) num_messages=len(in_context_messages), - num_archival_memory=passage_manager_size, - num_recall_memory=message_manager_size, + num_archival_memory=num_archival_memories, + num_recall_memory=num_messages, num_tokens_external_memory_summary=num_tokens_external_memory_summary, external_memory_summary=external_memory_summary, # top-level information diff --git a/letta/services/context_window_calculator/token_counter.py b/letta/services/context_window_calculator/token_counter.py index 3e1de4f7..52e43244 100644 --- a/letta/services/context_window_calculator/token_counter.py +++ b/letta/services/context_window_calculator/token_counter.py @@ -1,7 +1,11 @@ +import hashlib +import json from abc import ABC, abstractmethod from typing import Any, Dict, List +from letta.helpers.decorators import async_redis_cache from letta.llm_api.anthropic_client import AnthropicClient +from letta.otel.tracing import trace_method from letta.schemas.openai.chat_completion_request import Tool as OpenAITool from letta.utils import count_tokens @@ -33,16 +37,34 @@ class AnthropicTokenCounter(TokenCounter): self.client = anthropic_client self.model = model + @trace_method + @async_redis_cache( + key_func=lambda self, text: f"anthropic_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) async def count_text_tokens(self, text: str) -> int: if not text: return 0 return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "content": text}]) + @trace_method + @async_redis_cache( + key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int: if not messages: return 0 return await self.client.count_tokens(model=self.model, messages=messages) + @trace_method + @async_redis_cache( + key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) async def count_tool_tokens(self, tools: List[OpenAITool]) -> int: if not tools: return 0 @@ -58,11 +80,23 @@ class TiktokenCounter(TokenCounter): def __init__(self, model: str): self.model = model + @trace_method + @async_redis_cache( + key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) async def count_text_tokens(self, text: str) -> int: if not text: return 0 return count_tokens(text) + @trace_method + @async_redis_cache( + key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int: if not messages: return 0 @@ -70,6 +104,12 @@ class TiktokenCounter(TokenCounter): return num_tokens_from_messages(messages=messages, model=self.model) + @trace_method + @async_redis_cache( + key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) async def count_tool_tokens(self, tools: List[OpenAITool]) -> int: if not tools: return 0 diff --git a/letta/services/file_processor/chunker/line_chunker.py b/letta/services/file_processor/chunker/line_chunker.py index c06f024b..fe5ed031 100644 --- a/letta/services/file_processor/chunker/line_chunker.py +++ b/letta/services/file_processor/chunker/line_chunker.py @@ -40,6 +40,10 @@ class LineChunker: def _chunk_by_lines(self, text: str, preserve_indentation: bool = False) -> List[str]: """Traditional line-based chunking for code and structured data""" + # early stop, can happen if the there's nothing on a specific file + if not text: + return [] + lines = [] for line in text.splitlines(): if preserve_indentation: @@ -57,6 +61,10 @@ class LineChunker: def _chunk_by_sentences(self, text: str) -> List[str]: """Sentence-based chunking for documentation and markup""" + # early stop, can happen if the there's nothing on a specific file + if not text: + return [] + # Simple sentence splitting on periods, exclamation marks, and question marks # followed by whitespace or end of string sentence_pattern = r"(?<=[.!?])\s+(?=[A-Z])" @@ -75,6 +83,10 @@ class LineChunker: def _chunk_by_characters(self, text: str, target_line_length: int = 100) -> List[str]: """Character-based wrapping for prose text""" + # early stop, can happen if the there's nothing on a specific file + if not text: + return [] + words = text.split() lines = [] current_line = [] @@ -110,6 +122,11 @@ class LineChunker: strategy = self._determine_chunking_strategy(file_metadata) text = file_metadata.content + # early stop, can happen if the there's nothing on a specific file + if not text: + logger.warning(f"File ({file_metadata}) has no content") + return [] + # Apply the appropriate chunking strategy if strategy == ChunkingStrategy.DOCUMENTATION: content_lines = self._chunk_by_sentences(text) diff --git a/letta/services/file_processor/embedder/openai_embedder.py b/letta/services/file_processor/embedder/openai_embedder.py index 5a888549..b55ba936 100644 --- a/letta/services/file_processor/embedder/openai_embedder.py +++ b/letta/services/file_processor/embedder/openai_embedder.py @@ -25,7 +25,6 @@ class OpenAIEmbedder(BaseEmbedder): else EmbeddingConfig.default_config(model_name="letta") ) self.embedding_config = embedding_config or self.default_embedding_config - self.max_concurrent_requests = 20 # TODO: Unify to global OpenAI client self.client: OpenAIClient = cast( @@ -48,9 +47,55 @@ class OpenAIEmbedder(BaseEmbedder): "embedding_endpoint_type": self.embedding_config.embedding_endpoint_type, }, ) - embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config) - log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)}) - return [(idx, e) for idx, e in zip(batch_indices, embeddings)] + + try: + embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config) + log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)}) + return [(idx, e) for idx, e in zip(batch_indices, embeddings)] + except Exception as e: + # if it's a token limit error and we can split, do it + if self._is_token_limit_error(e) and len(batch) > 1: + logger.warning(f"Token limit exceeded for batch of size {len(batch)}, splitting in half and retrying") + log_event( + "embedder.batch_split_retry", + { + "original_batch_size": len(batch), + "error": str(e), + "split_size": len(batch) // 2, + }, + ) + + # split batch in half + mid = len(batch) // 2 + batch1 = batch[:mid] + batch1_indices = batch_indices[:mid] + batch2 = batch[mid:] + batch2_indices = batch_indices[mid:] + + # retry with smaller batches + result1 = await self._embed_batch(batch1, batch1_indices) + result2 = await self._embed_batch(batch2, batch2_indices) + + return result1 + result2 + else: + # re-raise for other errors or if batch size is already 1 + raise + + def _is_token_limit_error(self, error: Exception) -> bool: + """Check if the error is due to token limit exceeded""" + # convert to string and check for token limit patterns + error_str = str(error).lower() + + # TODO: This is quite brittle, works for now + # check for the specific patterns we see in token limit errors + is_token_limit = ( + "max_tokens_per_request" in error_str + or ("requested" in error_str and "tokens" in error_str and "max" in error_str and "per request" in error_str) + or "token limit" in error_str + or ("bad request to openai" in error_str and "tokens" in error_str and "max" in error_str) + ) + + return is_token_limit @trace_method async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]: @@ -100,7 +145,7 @@ class OpenAIEmbedder(BaseEmbedder): log_event( "embedder.concurrent_processing_started", - {"concurrent_tasks": len(tasks), "max_concurrent_requests": self.max_concurrent_requests}, + {"concurrent_tasks": len(tasks)}, ) results = await asyncio.gather(*tasks) log_event("embedder.concurrent_processing_completed", {"batches_processed": len(results)}) diff --git a/letta/services/files_agents_manager.py b/letta/services/files_agents_manager.py index a4abab31..0264f8dc 100644 --- a/letta/services/files_agents_manager.py +++ b/letta/services/files_agents_manager.py @@ -29,6 +29,7 @@ class FileAgentManager: agent_id: str, file_id: str, file_name: str, + source_id: str, actor: PydanticUser, is_open: bool = True, visible_content: Optional[str] = None, @@ -47,7 +48,12 @@ class FileAgentManager: if is_open: # Use the efficient LRU + open method closed_files, was_already_open = await self.enforce_max_open_files_and_open( - agent_id=agent_id, file_id=file_id, file_name=file_name, actor=actor, visible_content=visible_content or "" + agent_id=agent_id, + file_id=file_id, + file_name=file_name, + source_id=source_id, + actor=actor, + visible_content=visible_content or "", ) # Get the updated file agent to return @@ -85,6 +91,7 @@ class FileAgentManager: agent_id=agent_id, file_id=file_id, file_name=file_name, + source_id=source_id, organization_id=actor.organization_id, is_open=is_open, visible_content=visible_content, @@ -327,7 +334,7 @@ class FileAgentManager: @enforce_types @trace_method async def enforce_max_open_files_and_open( - self, *, agent_id: str, file_id: str, file_name: str, actor: PydanticUser, visible_content: str + self, *, agent_id: str, file_id: str, file_name: str, source_id: str, actor: PydanticUser, visible_content: str ) -> tuple[List[str], bool]: """ Efficiently handle LRU eviction and file opening in a single transaction. @@ -336,6 +343,7 @@ class FileAgentManager: agent_id: ID of the agent file_id: ID of the file to open file_name: Name of the file to open + source_id: ID of the source (denormalized from files.source_id) actor: User performing the action visible_content: Content to set for the opened file @@ -418,6 +426,7 @@ class FileAgentManager: agent_id=agent_id, file_id=file_id, file_name=file_name, + source_id=source_id, organization_id=actor.organization_id, is_open=True, visible_content=visible_content, @@ -516,6 +525,7 @@ class FileAgentManager: agent_id=agent_id, file_id=meta.id, file_name=meta.file_name, + source_id=meta.source_id, organization_id=actor.organization_id, is_open=is_now_open, visible_content=vc, diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index d2b0a501..2be87789 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -19,8 +19,8 @@ from letta.utils import enforce_types class GroupManager: - @trace_method @enforce_types + @trace_method def list_groups( self, actor: PydanticUser, @@ -45,22 +45,22 @@ class GroupManager: ) return [group.to_pydantic() for group in groups] - @trace_method @enforce_types + @trace_method def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup: with db_registry.session() as session: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) return group.to_pydantic() - @trace_method @enforce_types + @trace_method async def retrieve_group_async(self, group_id: str, actor: PydanticUser) -> PydanticGroup: async with db_registry.async_session() as session: group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) return group.to_pydantic() - @trace_method @enforce_types + @trace_method def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup: with db_registry.session() as session: new_group = GroupModel() @@ -150,8 +150,8 @@ class GroupManager: await new_group.create_async(session, actor=actor) return new_group.to_pydantic() - @trace_method @enforce_types + @trace_method async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup: async with db_registry.async_session() as session: group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) @@ -213,16 +213,16 @@ class GroupManager: await group.update_async(session, actor=actor) return group.to_pydantic() - @trace_method @enforce_types + @trace_method def delete_group(self, group_id: str, actor: PydanticUser) -> None: with db_registry.session() as session: # Retrieve the agent group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) group.hard_delete(session) - @trace_method @enforce_types + @trace_method def list_group_messages( self, actor: PydanticUser, @@ -258,8 +258,8 @@ class GroupManager: return messages - @trace_method @enforce_types + @trace_method def reset_messages(self, group_id: str, actor: PydanticUser) -> None: with db_registry.session() as session: # Ensure group is loadable by user @@ -272,8 +272,8 @@ class GroupManager: session.commit() - @trace_method @enforce_types + @trace_method def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int: with db_registry.session() as session: # Ensure group is loadable by user @@ -284,8 +284,8 @@ class GroupManager: group.update(session, actor=actor) return group.turns_counter - @trace_method @enforce_types + @trace_method async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int: async with db_registry.async_session() as session: # Ensure group is loadable by user @@ -309,8 +309,8 @@ class GroupManager: return prev_last_processed_message_id - @trace_method @enforce_types + @trace_method async def get_last_processed_message_id_and_update_async( self, group_id: str, last_processed_message_id: str, actor: PydanticUser ) -> str: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 9e3ee4d2..b3cd2c04 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -1,8 +1,12 @@ import asyncio from typing import List, Optional +from sqlalchemy import select + +from letta.orm import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.source import Source as SourceModel +from letta.orm.sources_agents import SourcesAgents from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.source import Source as PydanticSource @@ -104,9 +108,21 @@ class SourceManager: # Verify source exists and user has permission to access it source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor) - # The agents relationship is already loaded due to lazy="selectin" in the Source model - # and will be properly filtered by organization_id due to the OrganizationMixin - agents_orm = source.agents + # Use junction table query instead of relationship to avoid performance issues + query = ( + select(AgentModel) + .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id) + .where( + SourcesAgents.source_id == source_id, + AgentModel.organization_id == actor.organization_id if actor else True, + AgentModel.is_deleted == False, + ) + .order_by(AgentModel.created_at.desc(), AgentModel.id) + ) + + result = await session.execute(query) + agents_orm = result.scalars().all() + return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm]) # TODO: We make actor optional for now, but should most likely be enforced due to security reasons diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index 11c487d0..8883890f 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -188,7 +188,7 @@ class LettaCoreToolExecutor(ToolExecutor): Append to the contents of core memory. Args: - label (str): Section of the memory to be edited (persona or human). + label (str): Section of the memory to be edited. content (str): Content to write to the memory. All unicode (including emojis) are supported. Returns: @@ -214,7 +214,7 @@ class LettaCoreToolExecutor(ToolExecutor): Replace the contents of core memory. To delete memories, use an empty string for new_content. Args: - label (str): Section of the memory to be edited (persona or human). + label (str): Section of the memory to be edited. old_content (str): String to replace. Must be an exact match. new_content (str): Content to write to the memory. All unicode (including emojis) are supported. diff --git a/letta/services/tool_executor/files_tool_executor.py b/letta/services/tool_executor/files_tool_executor.py index 4815243a..b56b2253 100644 --- a/letta/services/tool_executor/files_tool_executor.py +++ b/letta/services/tool_executor/files_tool_executor.py @@ -180,7 +180,12 @@ class LettaFileToolExecutor(ToolExecutor): # Handle LRU eviction and file opening closed_files, was_already_open = await self.files_agents_manager.enforce_max_open_files_and_open( - agent_id=agent_state.id, file_id=file_id, file_name=file_name, actor=self.actor, visible_content=visible_content + agent_id=agent_state.id, + file_id=file_id, + file_name=file_name, + source_id=file.source_id, + actor=self.actor, + visible_content=visible_content, ) opened_files.append(file_name) diff --git a/poetry.lock b/poetry.lock index f771114b..a27db5d9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3591,14 +3591,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"] [[package]] name = "letta-client" -version = "0.1.198" +version = "0.1.197" description = "" optional = false python-versions = "<4.0,>=3.8" groups = ["main"] files = [ - {file = "letta_client-0.1.198-py3-none-any.whl", hash = "sha256:08bbc238b128da2552b2a6e54feb3294794b5586e0962ce0bb95bb525109f58f"}, - {file = "letta_client-0.1.198.tar.gz", hash = "sha256:990c9132423e2955d9c7f7549e5064b2366616232d270e5927788cddba4ef9da"}, + {file = "letta_client-0.1.197-py3-none-any.whl", hash = "sha256:b01eab01ff87a34e79622cd8d3a2f3da56b6bd730312a268d968576671f6cce0"}, + {file = "letta_client-0.1.197.tar.gz", hash = "sha256:579571623ccec81422087cb7957d5f580fa0b7a53f8495596e5c56d0213220d8"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index ac4991af..57a8f27a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.8.13" +version = "0.8.14" packages = [ {include = "letta"}, ] @@ -116,7 +116,6 @@ google = ["google-genai"] desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"] all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone"] - [tool.poetry.group.dev.dependencies] black = "^24.4.2" ipykernel = "^6.29.5" diff --git a/test_agent_serialization.json b/test_agent_serialization.json index 6efaab99..818d9e7e 100644 --- a/test_agent_serialization.json +++ b/test_agent_serialization.json @@ -306,7 +306,7 @@ "properties": { "label": { "type": "string", - "description": "Section of the memory to be edited (persona or human)." + "description": "Section of the memory to be edited." }, "content": { "type": "string", @@ -343,7 +343,7 @@ "properties": { "label": { "type": "string", - "description": "Section of the memory to be edited (persona or human)." + "description": "Section of the memory to be edited." }, "old_content": { "type": "string", diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 82abb3a7..950e44d0 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -237,7 +237,7 @@ def validate_context_window_overview( # 16. Check attached file is visible if attached_file: - assert attached_file.visible_content in overview.core_memory + assert attached_file.visible_content in overview.core_memory, "File must be attached in core memory" assert '" in overview.core_memory diff --git a/tests/test_file_processor.py b/tests/test_file_processor.py new file mode 100644 index 00000000..e2448e2e --- /dev/null +++ b/tests/test_file_processor.py @@ -0,0 +1,219 @@ +from unittest.mock import AsyncMock, Mock, patch + +import openai +import pytest + +from letta.errors import ErrorCode, LLMBadRequestError +from letta.schemas.embedding_config import EmbeddingConfig +from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder + + +class TestOpenAIEmbedder: + """Test suite for OpenAI embedder functionality""" + + @pytest.fixture + def mock_user(self): + """Create a mock user for testing""" + user = Mock() + user.organization_id = "test_org_id" + return user + + @pytest.fixture + def embedding_config(self): + """Create a test embedding config""" + return EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=3, # small dimension for testing + embedding_chunk_size=300, + batch_size=2, # small batch size for testing + ) + + @pytest.fixture + def embedder(self, embedding_config): + """Create OpenAI embedder with test config""" + with patch("letta.services.file_processor.embedder.openai_embedder.LLMClient.create") as mock_create: + mock_client = Mock() + mock_client.handle_llm_error = Mock() + mock_create.return_value = mock_client + + embedder = OpenAIEmbedder(embedding_config) + embedder.client = mock_client + return embedder + + @pytest.mark.asyncio + async def test_successful_embedding_generation(self, embedder, mock_user): + """Test successful embedding generation for normal cases""" + # mock successful embedding response + mock_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + embedder.client.request_embeddings = AsyncMock(return_value=mock_embeddings) + + chunks = ["chunk 1", "chunk 2"] + file_id = "test_file" + source_id = "test_source" + + passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + assert len(passages) == 2 + assert passages[0].text == "chunk 1" + assert passages[1].text == "chunk 2" + # embeddings are padded to MAX_EMBEDDING_DIM, so check first 3 values + assert passages[0].embedding[:3] == [0.1, 0.2, 0.3] + assert passages[1].embedding[:3] == [0.4, 0.5, 0.6] + assert passages[0].file_id == file_id + assert passages[0].source_id == source_id + + @pytest.mark.asyncio + async def test_token_limit_retry_splits_batch(self, embedder, mock_user): + """Test that token limit errors trigger batch splitting and retry""" + # create a mock token limit error + mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}} + token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body) + + # first call fails with token limit, subsequent calls succeed + call_count = 0 + + async def mock_request_embeddings(inputs, embedding_config): + nonlocal call_count + call_count += 1 + if call_count == 1 and len(inputs) == 4: # first call with full batch + raise token_limit_error + elif len(inputs) == 2: # split batches succeed + return [[0.1, 0.2], [0.3, 0.4]] if call_count == 2 else [[0.5, 0.6], [0.7, 0.8]] + else: + return [[0.1, 0.2]] * len(inputs) + + embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings) + + chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"] + file_id = "test_file" + source_id = "test_source" + + passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + # should still get all 4 passages despite the retry + assert len(passages) == 4 + assert all(len(p.embedding) == 4096 for p in passages) # padded to MAX_EMBEDDING_DIM + # verify multiple calls were made (original + retries) + assert call_count >= 2 + + @pytest.mark.asyncio + async def test_token_limit_error_detection(self, embedder): + """Test various token limit error detection patterns""" + # test openai BadRequestError with proper structure + mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}} + openai_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body) + assert embedder._is_token_limit_error(openai_error) is True + + # test error with message but no code + mock_error_body_no_code = {"error": {"message": "max_tokens_per_request exceeded"}} + openai_error_no_code = openai.BadRequestError( + message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body_no_code + ) + assert embedder._is_token_limit_error(openai_error_no_code) is True + + # test fallback string detection + generic_error = Exception("Requested 100000 tokens, max 50000 tokens per request") + assert embedder._is_token_limit_error(generic_error) is True + + # test non-token errors + other_error = Exception("Some other error") + assert embedder._is_token_limit_error(other_error) is False + + auth_error = openai.AuthenticationError( + message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}} + ) + assert embedder._is_token_limit_error(auth_error) is False + + @pytest.mark.asyncio + async def test_non_token_error_handling(self, embedder, mock_user): + """Test that non-token errors are properly handled and re-raised""" + # create a non-token error + auth_error = openai.AuthenticationError( + message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}} + ) + + # mock handle_llm_error to return a standardized error + handled_error = LLMBadRequestError(message="Handled error", code=ErrorCode.UNAUTHENTICATED) + embedder.client.handle_llm_error.return_value = handled_error + embedder.client.request_embeddings = AsyncMock(side_effect=auth_error) + + chunks = ["chunk 1"] + file_id = "test_file" + source_id = "test_source" + + with pytest.raises(LLMBadRequestError) as exc_info: + await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + assert exc_info.value == handled_error + embedder.client.handle_llm_error.assert_called_once_with(auth_error) + + @pytest.mark.asyncio + async def test_single_item_batch_no_retry(self, embedder, mock_user): + """Test that single-item batches don't retry on token limit errors""" + # create a token limit error + mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}} + token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body) + + handled_error = LLMBadRequestError(message="Handled token limit error", code=ErrorCode.INVALID_ARGUMENT) + embedder.client.handle_llm_error.return_value = handled_error + embedder.client.request_embeddings = AsyncMock(side_effect=token_limit_error) + + chunks = ["very long chunk that exceeds token limit"] + file_id = "test_file" + source_id = "test_source" + + with pytest.raises(LLMBadRequestError) as exc_info: + await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + assert exc_info.value == handled_error + embedder.client.handle_llm_error.assert_called_once_with(token_limit_error) + + @pytest.mark.asyncio + async def test_empty_chunks_handling(self, embedder, mock_user): + """Test handling of empty chunks list""" + chunks = [] + file_id = "test_file" + source_id = "test_source" + + passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + assert passages == [] + # should not call request_embeddings for empty input + embedder.client.request_embeddings.assert_not_called() + + @pytest.mark.asyncio + async def test_embedding_order_preservation(self, embedder, mock_user): + """Test that embedding order is preserved even with retries""" + # set up embedder to split batches (batch_size=2) + embedder.embedding_config.batch_size = 2 + + # mock responses for each batch + async def mock_request_embeddings(inputs, embedding_config): + # return embeddings that correspond to input order + if inputs == ["chunk 1", "chunk 2"]: + return [[0.1, 0.1], [0.2, 0.2]] + elif inputs == ["chunk 3", "chunk 4"]: + return [[0.3, 0.3], [0.4, 0.4]] + else: + return [[0.1, 0.1]] * len(inputs) + + embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings) + + chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"] + file_id = "test_file" + source_id = "test_source" + + passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user) + + # verify order is preserved + assert len(passages) == 4 + assert passages[0].text == "chunk 1" + assert passages[0].embedding[:2] == [0.1, 0.1] # check first 2 values before padding + assert passages[1].text == "chunk 2" + assert passages[1].embedding[:2] == [0.2, 0.2] + assert passages[2].text == "chunk 3" + assert passages[2].embedding[:2] == [0.3, 0.3] + assert passages[3].text == "chunk 4" + assert passages[3].embedding[:2] == [0.4, 0.4] diff --git a/tests/test_managers.py b/tests/test_managers.py index 2048533e..f8833f8f 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -673,6 +673,7 @@ async def file_attachment(server, default_user, sarah_agent, default_file): agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, visible_content="initial", ) @@ -903,6 +904,7 @@ async def test_get_context_window_basic( agent_id=created_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, visible_content="hello", ) @@ -7221,6 +7223,7 @@ async def test_attach_creates_association(server, default_user, sarah_agent, def agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, visible_content="hello", ) @@ -7243,6 +7246,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, visible_content="first", ) @@ -7252,6 +7256,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, is_open=False, visible_content="second", @@ -7326,15 +7331,28 @@ async def test_list_files_and_agents( ): # default_file ↔ charles (open) await server.file_agent_manager.attach_file( - agent_id=charles_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user + agent_id=charles_agent.id, + file_id=default_file.id, + file_name=default_file.file_name, + source_id=default_file.source_id, + actor=default_user, ) # default_file ↔ sarah (open) await server.file_agent_manager.attach_file( - agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user + agent_id=sarah_agent.id, + file_id=default_file.id, + file_name=default_file.file_name, + source_id=default_file.source_id, + actor=default_user, ) # another_file ↔ sarah (closed) await server.file_agent_manager.attach_file( - agent_id=sarah_agent.id, file_id=another_file.id, file_name=another_file.file_name, actor=default_user, is_open=False + agent_id=sarah_agent.id, + file_id=another_file.id, + file_name=another_file.file_name, + source_id=another_file.source_id, + actor=default_user, + is_open=False, ) files_for_sarah = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user) @@ -7384,6 +7402,7 @@ async def test_org_scoping( agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, + source_id=default_file.source_id, actor=default_user, ) @@ -7420,6 +7439,7 @@ async def test_mark_access_bulk(server, default_user, sarah_agent, default_sourc agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, + source_id=file.source_id, actor=default_user, visible_content=f"content for {file.file_name}", ) @@ -7478,6 +7498,7 @@ async def test_lru_eviction_on_attach(server, default_user, sarah_agent, default agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, + source_id=file.source_id, actor=default_user, visible_content=f"content for {file.file_name}", ) @@ -7530,6 +7551,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa agent_id=sarah_agent.id, file_id=files[i].id, file_name=files[i].file_name, + source_id=files[i].source_id, actor=default_user, visible_content=f"content for {files[i].file_name}", ) @@ -7539,6 +7561,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, + source_id=files[-1].source_id, actor=default_user, is_open=False, visible_content=f"content for {files[-1].file_name}", @@ -7555,7 +7578,12 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa # Now "open" the last file using the efficient method closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open( - agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content" + agent_id=sarah_agent.id, + file_id=files[-1].id, + file_name=files[-1].file_name, + source_id=files[-1].source_id, + actor=default_user, + visible_content="updated content", ) # Should have closed 1 file (the oldest one) @@ -7603,6 +7631,7 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, + source_id=file.source_id, actor=default_user, visible_content=f"content for {file.file_name}", ) @@ -7617,7 +7646,12 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa # "Reopen" the last file (which is already open) closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open( - agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content" + agent_id=sarah_agent.id, + file_id=files[-1].id, + file_name=files[-1].file_name, + source_id=files[-1].source_id, + actor=default_user, + visible_content="updated content", ) # Should not have closed any files since we're within the limit @@ -7645,7 +7679,12 @@ async def test_last_accessed_at_updates_correctly(server, default_user, sarah_ag file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user, text="test content") file_agent, closed_files = await server.file_agent_manager.attach_file( - agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, actor=default_user, visible_content="initial content" + agent_id=sarah_agent.id, + file_id=file.id, + file_name=file.file_name, + source_id=file.source_id, + actor=default_user, + visible_content="initial content", ) initial_time = file_agent.last_accessed_at @@ -7777,6 +7816,7 @@ async def test_attach_files_bulk_lru_eviction(server, default_user, sarah_agent, agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, + source_id=file.source_id, actor=default_user, visible_content=f"existing content {i}", ) @@ -7842,6 +7882,7 @@ async def test_attach_files_bulk_mixed_existing_new(server, default_user, sarah_ agent_id=sarah_agent.id, file_id=existing_file.id, file_name=existing_file.file_name, + source_id=existing_file.source_id, actor=default_user, visible_content="old content", is_open=False, # Start as closed diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 7ff37851..71a72322 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1033,3 +1033,36 @@ def test_preview_payload(client: LettaSDKClient, agent): assert payload["user"].startswith("user-") print(payload) + + +def test_agent_tools_list(client: LettaSDKClient): + """Test the optimized agent tools list endpoint for correctness.""" + # Create a test agent + agent_state = client.agents.create( + name="test_agent_tools_list", + memory_blocks=[ + CreateBlock( + label="persona", + value="You are a helpful assistant.", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + include_base_tools=True, + ) + + try: + # Test basic functionality + tools = client.agents.tools.list(agent_id=agent_state.id) + assert len(tools) > 0, "Agent should have base tools attached" + + # Verify tool objects have expected attributes + for tool in tools: + assert hasattr(tool, "id"), "Tool should have id attribute" + assert hasattr(tool, "name"), "Tool should have name attribute" + assert tool.id is not None, "Tool id should not be None" + assert tool.name is not None, "Tool name should not be None" + + finally: + # Clean up + client.agents.delete(agent_id=agent_state.id) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5b0e724c..5da46f61 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -464,9 +464,8 @@ def test_line_chunker_edge_case_empty_file(): file = FileMetadata(file_name="empty.py", source_id="test_source", content="") chunker = LineChunker() - # Test requesting lines from empty file - with pytest.raises(ValueError, match="File empty.py has only 0 lines, but requested offset 1 is out of range"): - chunker.chunk_text(file, start=0, end=1, validate_range=True) + # no error + chunker.chunk_text(file, start=0, end=1, validate_range=True) def test_line_chunker_edge_case_single_line():