chore: bump version 0.8.14 (#2720)

Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: jnjpng <jin@letta.com>
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
Co-authored-by: cpacker <packercharles@gmail.com>
Co-authored-by: Shubham Naik <shub@letta.com>
Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
This commit is contained in:
cthomas
2025-07-14 11:03:15 -07:00
committed by GitHub
parent 7132104f4d
commit 33eaabb04a
35 changed files with 944 additions and 310 deletions

View File

@@ -0,0 +1,52 @@
"""Write source_id directly to files agents
Revision ID: 495f3f474131
Revises: 47d2277e530d
Create Date: 2025-07-10 17:14:45.154738
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "495f3f474131"
down_revision: Union[str, None] = "47d2277e530d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add a NOT NULL ``source_id`` column to ``files_agents``.

    The value is denormalized from ``files.source_id`` so that
    ``files_agents`` rows can reference their source without joining
    through ``files``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Step 1: Add the column as nullable first so existing rows don't violate NOT NULL
    op.add_column("files_agents", sa.Column("source_id", sa.String(), nullable=True))

    # Step 2: Backfill source_id from the files table
    # NOTE(review): "UPDATE ... FROM" is PostgreSQL syntax — confirm the target
    # backend before running this migration elsewhere (e.g. SQLite).
    connection = op.get_bind()
    connection.execute(
        sa.text(
            """
            UPDATE files_agents
            SET source_id = files.source_id
            FROM files
            WHERE files_agents.file_id = files.id
            """
        )
    )

    # Step 3: Make the column NOT NULL now that it's populated
    op.alter_column("files_agents", "source_id", existing_type=sa.String(), nullable=False)

    # Step 4: Add the foreign key constraint.
    # Name the constraint explicitly: passing None leaves the name to the
    # backend's auto-naming, which makes it impossible to drop deterministically
    # in downgrade() (op.drop_constraint requires a name and raises on None).
    op.create_foreign_key(
        "fk_files_agents_source_id_sources",
        "files_agents",
        "sources",
        ["source_id"],
        ["id"],
        ondelete="CASCADE",
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Remove the denormalized ``source_id`` column from ``files_agents``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # drop_constraint() needs the explicit name created in upgrade();
    # the previous drop_constraint(None, ...) would raise at runtime.
    op.drop_constraint("fk_files_agents_source_id_sources", "files_agents", type_="foreignkey")
    op.drop_column("files_agents", "source_id")
    # ### end Alembic commands ###

View File

@@ -5,7 +5,7 @@ try:
__version__ = version("letta")
except PackageNotFoundError:
# Fallback for development installations
__version__ = "0.8.13"
__version__ = "0.8.14"
if os.environ.get("LETTA_VERSION"):
__version__ = os.environ["LETTA_VERSION"]

View File

@@ -372,3 +372,9 @@ PINECONE_METRIC = "cosine"
PINECONE_CLOUD = "aws"
PINECONE_REGION = "us-east-1"
PINECONE_MAX_BATCH_SIZE = 96
# retry configuration
PINECONE_MAX_RETRY_ATTEMPTS = 3
PINECONE_RETRY_BASE_DELAY = 1.0 # seconds
PINECONE_RETRY_MAX_DELAY = 60.0 # seconds
PINECONE_RETRY_BACKOFF_FACTOR = 2.0

View File

@@ -135,7 +135,7 @@ def core_memory_append(agent_state: "AgentState", label: str, content: str) -> O
Append to the contents of core memory.
Args:
label (str): Section of the memory to be edited (persona or human).
label (str): Section of the memory to be edited.
content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
@@ -152,7 +152,7 @@ def core_memory_replace(agent_state: "AgentState", label: str, old_content: str,
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Args:
label (str): Section of the memory to be edited (persona or human).
label (str): Section of the memory to be edited.
old_content (str): String to replace. Must be an exact match.
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.

View File

@@ -1,8 +1,20 @@
import asyncio
import random
import time
from functools import wraps
from typing import Any, Dict, List
from letta.otel.tracing import trace_method
try:
from pinecone import IndexEmbed, PineconeAsyncio
from pinecone.exceptions.exceptions import NotFoundException
from pinecone.exceptions.exceptions import (
ForbiddenException,
NotFoundException,
PineconeApiException,
ServiceException,
UnauthorizedException,
)
PINECONE_AVAILABLE = True
except ImportError:
@@ -12,8 +24,12 @@ from letta.constants import (
PINECONE_CLOUD,
PINECONE_EMBEDDING_MODEL,
PINECONE_MAX_BATCH_SIZE,
PINECONE_MAX_RETRY_ATTEMPTS,
PINECONE_METRIC,
PINECONE_REGION,
PINECONE_RETRY_BACKOFF_FACTOR,
PINECONE_RETRY_BASE_DELAY,
PINECONE_RETRY_MAX_DELAY,
PINECONE_TEXT_FIELD_NAME,
)
from letta.log import get_logger
@@ -23,6 +39,87 @@ from letta.settings import settings
logger = get_logger(__name__)
def pinecone_retry(
    max_attempts: int = PINECONE_MAX_RETRY_ATTEMPTS,
    base_delay: float = PINECONE_RETRY_BASE_DELAY,
    max_delay: float = PINECONE_RETRY_MAX_DELAY,
    backoff_factor: float = PINECONE_RETRY_BACKOFF_FACTOR,
):
    """
    Decorator to retry async Pinecone operations with exponential backoff.

    Retryable: ``ServiceException`` / generic ``PineconeApiException`` (server-side
    errors) and any other unexpected ``Exception`` — each retried up to
    ``max_attempts`` total attempts with capped exponential backoff plus jitter.
    Non-retryable: auth errors (``UnauthorizedException``, ``ForbiddenException``)
    and ``NotFoundException`` are re-raised immediately.

    Args:
        max_attempts: Maximum number of attempts (including the first call)
        base_delay: Base delay in seconds for the first retry
        max_delay: Maximum delay in seconds between retries
        backoff_factor: Factor to increase delay after each failed attempt
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            operation_name = func.__name__
            start_time = time.time()

            def _backoff_delay(attempt: int) -> float:
                # exponential backoff capped at max_delay, plus up to 10% jitter
                delay = min(base_delay * (backoff_factor**attempt), max_delay)
                return delay + random.uniform(0, delay * 0.1)

            for attempt in range(max_attempts):
                try:
                    logger.debug(f"[Pinecone] Starting {operation_name} (attempt {attempt + 1}/{max_attempts})")
                    result = await func(*args, **kwargs)
                    execution_time = time.time() - start_time
                    logger.info(f"[Pinecone] {operation_name} completed successfully in {execution_time:.2f}s")
                    return result
                except (UnauthorizedException, ForbiddenException) as e:
                    # non-retryable auth errors. These MUST be caught before the
                    # generic (ServiceException, PineconeApiException) clause below:
                    # they subclass PineconeApiException, so listing the generic
                    # clause first (as before) caused auth failures to be retried.
                    execution_time = time.time() - start_time
                    logger.error(f"[Pinecone] {operation_name} failed with auth error in {execution_time:.2f}s: {str(e)}")
                    raise
                except NotFoundException as e:
                    # non-retryable not found errors (also a PineconeApiException subclass)
                    execution_time = time.time() - start_time
                    logger.warning(f"[Pinecone] {operation_name} failed with not found error in {execution_time:.2f}s: {str(e)}")
                    raise
                except (ServiceException, PineconeApiException) as e:
                    # retryable server errors
                    if attempt == max_attempts - 1:
                        execution_time = time.time() - start_time
                        logger.error(f"[Pinecone] {operation_name} failed after {max_attempts} attempts in {execution_time:.2f}s: {str(e)}")
                        raise
                    total_delay = _backoff_delay(attempt)
                    logger.warning(
                        f"[Pinecone] {operation_name} failed (attempt {attempt + 1}/{max_attempts}): {str(e)}. Retrying in {total_delay:.2f}s"
                    )
                    await asyncio.sleep(total_delay)
                except Exception as e:
                    # other unexpected errors — retried on the same backoff
                    # schedule up to max_attempts (the old comment claimed
                    # "retry once", which did not match the code)
                    if attempt == max_attempts - 1:
                        execution_time = time.time() - start_time
                        logger.error(f"[Pinecone] {operation_name} failed after {max_attempts} attempts in {execution_time:.2f}s: {str(e)}")
                        raise
                    total_delay = _backoff_delay(attempt)
                    logger.warning(
                        f"[Pinecone] {operation_name} failed with unexpected error (attempt {attempt + 1}/{max_attempts}): {str(e)}. Retrying in {total_delay:.2f}s"
                    )
                    await asyncio.sleep(total_delay)

        return wrapper

    return decorator
def should_use_pinecone(verbose: bool = False):
if verbose:
logger.info(
@@ -44,29 +141,42 @@ def should_use_pinecone(verbose: bool = False):
)
@pinecone_retry()
@trace_method
async def upsert_pinecone_indices():
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
for index_name in get_pinecone_indices():
indices = get_pinecone_indices()
logger.info(f"[Pinecone] Upserting {len(indices)} indices: {indices}")
for index_name in indices:
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
if not await pc.has_index(index_name):
logger.info(f"[Pinecone] Creating index {index_name} with model {PINECONE_EMBEDDING_MODEL}")
await pc.create_index_for_model(
name=index_name,
cloud=PINECONE_CLOUD,
region=PINECONE_REGION,
embed=IndexEmbed(model=PINECONE_EMBEDDING_MODEL, field_map={"text": PINECONE_TEXT_FIELD_NAME}, metric=PINECONE_METRIC),
)
logger.info(f"[Pinecone] Successfully created index {index_name}")
else:
logger.debug(f"[Pinecone] Index {index_name} already exists")
def get_pinecone_indices() -> List[str]:
    """Return the names of the Pinecone indices in use (agent index and source index, from settings)."""
    return [settings.pinecone_agent_index, settings.pinecone_source_index]
@pinecone_retry()
@trace_method
async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, chunks: List[str], actor: User):
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
logger.info(f"[Pinecone] Preparing to upsert {len(chunks)} chunks for file {file_id} source {source_id}")
records = []
for i, chunk in enumerate(chunks):
record = {
@@ -77,14 +187,19 @@ async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, ch
}
records.append(record)
logger.debug(f"[Pinecone] Created {len(records)} records for file {file_id}")
return await upsert_records_to_pinecone_index(records, actor)
@pinecone_retry()
@trace_method
async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
namespace = actor.organization_id
logger.info(f"[Pinecone] Deleting records for file {file_id} from index {settings.pinecone_source_index} namespace {namespace}")
try:
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
@@ -95,48 +210,72 @@ async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
},
namespace=namespace,
)
logger.info(f"[Pinecone] Successfully deleted records for file {file_id}")
except NotFoundException:
logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
logger.warning(f"[Pinecone] Namespace {namespace} not found for file {file_id} and org {actor.organization_id}")
@pinecone_retry()
@trace_method
async def delete_source_records_from_pinecone_index(source_id: str, actor: User):
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
namespace = actor.organization_id
logger.info(f"[Pinecone] Deleting records for source {source_id} from index {settings.pinecone_source_index} namespace {namespace}")
try:
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
await dense_index.delete(filter={"source_id": {"$eq": source_id}}, namespace=namespace)
logger.info(f"[Pinecone] Successfully deleted records for source {source_id}")
except NotFoundException:
logger.warning(f"Pinecone namespace {namespace} not found for {source_id} and {actor.organization_id}")
logger.warning(f"[Pinecone] Namespace {namespace} not found for source {source_id} and org {actor.organization_id}")
@pinecone_retry()
@trace_method
async def upsert_records_to_pinecone_index(records: List[dict], actor: User):
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
logger.info(f"[Pinecone] Upserting {len(records)} records to index {settings.pinecone_source_index} for org {actor.organization_id}")
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
# Process records in batches to avoid exceeding Pinecone limits
# process records in batches to avoid exceeding pinecone limits
total_batches = (len(records) + PINECONE_MAX_BATCH_SIZE - 1) // PINECONE_MAX_BATCH_SIZE
logger.debug(f"[Pinecone] Processing {total_batches} batches of max {PINECONE_MAX_BATCH_SIZE} records each")
for i in range(0, len(records), PINECONE_MAX_BATCH_SIZE):
batch = records[i : i + PINECONE_MAX_BATCH_SIZE]
batch_num = (i // PINECONE_MAX_BATCH_SIZE) + 1
logger.debug(f"[Pinecone] Upserting batch {batch_num}/{total_batches} with {len(batch)} records")
await dense_index.upsert_records(actor.organization_id, batch)
logger.info(f"[Pinecone] Successfully upserted all {len(records)} records in {total_batches} batches")
@pinecone_retry()
@trace_method
async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], actor: User) -> Dict[str, Any]:
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
namespace = actor.organization_id
logger.info(
f"[Pinecone] Searching index {settings.pinecone_source_index} namespace {namespace} with query length {len(query)} chars, limit {limit}"
)
logger.debug(f"[Pinecone] Search filter: {filter}")
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
namespace = actor.organization_id
try:
# Search the dense index with reranking
# search the dense index with reranking
search_results = await dense_index.search(
namespace=namespace,
query={
@@ -146,17 +285,26 @@ async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any],
},
rerank={"model": "bge-reranker-v2-m3", "top_n": limit, "rank_fields": [PINECONE_TEXT_FIELD_NAME]},
)
result_count = len(search_results.get("matches", []))
logger.info(f"[Pinecone] Search completed, found {result_count} matches")
return search_results
except Exception as e:
logger.warning(f"Failed to search Pinecone namespace {actor.organization_id}: {str(e)}")
logger.warning(f"[Pinecone] Failed to search namespace {namespace}: {str(e)}")
raise e
@pinecone_retry()
@trace_method
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
namespace = actor.organization_id
logger.info(f"[Pinecone] Listing records for file {file_id} from index {settings.pinecone_source_index} namespace {namespace}")
logger.debug(f"[Pinecone] List params - limit: {limit}, pagination_token: {pagination_token}")
try:
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
@@ -172,9 +320,14 @@ async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int =
result = []
async for ids in dense_index.list(**kwargs):
result.extend(ids)
logger.info(f"[Pinecone] Successfully listed {len(result)} records for file {file_id}")
return result
except Exception as e:
logger.warning(f"Failed to list Pinecone namespace {actor.organization_id}: {str(e)}")
logger.warning(f"[Pinecone] Failed to list records for file {file_id} in namespace {namespace}: {str(e)}")
raise e
except NotFoundException:
logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
logger.warning(f"[Pinecone] Namespace {namespace} not found for file {file_id} and org {actor.organization_id}")
return []

View File

@@ -1,5 +1,5 @@
import uuid
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from sqlalchemy import ForeignKey, Index, Integer, String, Text, UniqueConstraint, desc
from sqlalchemy.ext.asyncio import AsyncAttrs
@@ -11,10 +11,7 @@ from letta.schemas.enums import FileProcessingStatus
from letta.schemas.file import FileMetadata as PydanticFileMetadata
if TYPE_CHECKING:
from letta.orm.files_agents import FileAgent
from letta.orm.organization import Organization
from letta.orm.passage import SourcePassage
from letta.orm.source import Source
pass
# TODO: Note that this is NOT organization scoped, this is potentially dangerous if we misuse this
@@ -64,18 +61,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
chunks_embedded: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Number of chunks that have been embedded.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan"
)
file_agents: Mapped[List["FileAgent"]] = relationship(
"FileAgent",
back_populates="file",
lazy="selectin",
cascade="all, delete-orphan",
passive_deletes=True, # ← add this
)
content: Mapped[Optional["FileContent"]] = relationship(
"FileContent",
uselist=False,

View File

@@ -12,7 +12,7 @@ from letta.schemas.block import Block as PydanticBlock
from letta.schemas.file import FileAgent as PydanticFileAgent
if TYPE_CHECKING:
from letta.orm.file import FileMetadata
pass
class FileAgent(SqlalchemyBase, OrganizationMixin):
@@ -55,6 +55,12 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
nullable=False,
doc="ID of the agent",
)
source_id: Mapped[str] = mapped_column(
String,
ForeignKey("sources.id", ondelete="CASCADE"),
nullable=False,
doc="ID of the source (denormalized from files.source_id)",
)
file_name: Mapped[str] = mapped_column(
String,
@@ -78,13 +84,6 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
back_populates="file_agents",
lazy="selectin",
)
file: Mapped["FileMetadata"] = relationship(
"FileMetadata",
foreign_keys=[file_id],
lazy="selectin",
back_populates="file_agents",
passive_deletes=True, # ← add this
)
# TODO: This is temporary as we figure out if we want FileBlock as a first class citizen
def to_pydantic_block(self) -> PydanticBlock:
@@ -99,8 +98,8 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
return PydanticBlock(
organization_id=self.organization_id,
value=visible_content,
label=self.file.file_name,
label=self.file_name, # use denormalized file_name instead of self.file.file_name
read_only=True,
metadata={"source_id": self.file.source_id},
metadata={"source_id": self.source_id}, # use denormalized source_id
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
)

View File

@@ -9,7 +9,6 @@ if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.agent_passage import AgentPassage
from letta.orm.block import Block
from letta.orm.file import FileMetadata
from letta.orm.group import Group
from letta.orm.identity import Identity
from letta.orm.llm_batch_item import LLMBatchItem
@@ -18,7 +17,6 @@ if TYPE_CHECKING:
from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig
from letta.orm.sandbox_environment_variable import SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.source_passage import SourcePassage
from letta.orm.tool import Tool
from letta.orm.user import User
@@ -38,8 +36,6 @@ class Organization(SqlalchemyBase):
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
# mcp_servers: Mapped[List["MCPServer"]] = relationship("MCPServer", back_populates="organization", cascade="all, delete-orphan")
blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan")
sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
sandbox_configs: Mapped[List["SandboxConfig"]] = relationship(
"SandboxConfig", back_populates="organization", cascade="all, delete-orphan"
)

View File

@@ -49,11 +49,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
file_name: Mapped[str] = mapped_column(doc="The name of the file that this passage was derived from")
@declared_attr
def file(cls) -> Mapped["FileMetadata"]:
"""Relationship to file"""
return relationship("FileMetadata", back_populates="source_passages", lazy="selectin")
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="source_passages", lazy="selectin")
@@ -74,11 +69,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
{"extend_existing": True},
)
@declared_attr
def source(cls) -> Mapped["Source"]:
"""Relationship to source"""
return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True)
class AgentPassage(BasePassage, AgentMixin):
"""Passages created by agents as archival memories"""

View File

@@ -1,9 +1,8 @@
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from sqlalchemy import JSON, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm import FileMetadata
from letta.orm.custom_columns import EmbeddingConfigColumn
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
@@ -11,10 +10,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.source import Source as PydanticSource
if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.file import FileMetadata
from letta.orm.organization import Organization
from letta.orm.passage import SourcePassage
pass
class Source(SqlalchemyBase, OrganizationMixin):
@@ -34,16 +30,3 @@ class Source(SqlalchemyBase, OrganizationMixin):
instructions: Mapped[str] = mapped_column(nullable=True, doc="instructions for how to use the source")
embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.")
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship(
"Agent",
secondary="sources_agents",
back_populates="sources",
lazy="selectin",
cascade="save-update", # Only propagate save and update operations
passive_deletes=True, # Let the database handle deletions
)

View File

@@ -85,6 +85,7 @@ class FileAgent(FileAgentBase):
)
agent_id: str = Field(..., description="Unique identifier of the agent.")
file_id: str = Field(..., description="Unique identifier of the file.")
source_id: str = Field(..., description="Unique identifier of the source (denormalized from files.source_id).")
file_name: str = Field(..., description="Name of the file.")
is_open: bool = Field(True, description="True if the agent currently has the file open.")
visible_content: Optional[str] = Field(

View File

@@ -210,7 +210,7 @@ class BasicBlockMemory(Memory):
Append to the contents of core memory.
Args:
label (str): Section of the memory to be edited (persona or human).
label (str): Section of the memory to be edited.
content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
@@ -226,7 +226,7 @@ class BasicBlockMemory(Memory):
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Args:
label (str): Section of the memory to be edited (persona or human).
label (str): Section of the memory to be edited.
old_content (str): String to replace. Must be an exact match.
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.

View File

@@ -272,14 +272,14 @@ async def modify_agent(
@router.get("/{agent_id}/tools", response_model=list[Tool], operation_id="list_agent_tools")
def list_agent_tools(
async def list_agent_tools(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Get tools from an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.list_attached_tools_async(agent_id=agent_id, actor=actor)
@router.patch("/{agent_id}/tools/attach/{tool_id}", response_model=AgentState, operation_id="attach_tool")
@@ -1072,7 +1072,7 @@ async def _process_message_background(
completed_at=datetime.now(timezone.utc),
metadata={"error": str(e)},
)
await server.job_manager.update_job_by_id_async(job_id=job_id, job_update=job_update, actor=actor)
await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor)
@router.post(

View File

@@ -1,6 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Header, Query, status
from fastapi import APIRouter, Body, Depends, Header, Query
from fastapi.exceptions import HTTPException
from starlette.requests import Request
@@ -45,12 +45,8 @@ async def create_messages_batch(
if length > max_bytes:
raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
# Reject request if env var is not set
if not settings.enable_batch_job_polling:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.",
)
logger.warning("Batch job polling is disabled. Enable batch processing by setting LETTA_ENABLE_BATCH_JOB_POLLING to True.")
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
batch_job = BatchJob(

View File

@@ -391,8 +391,8 @@ async def get_file_metadata(
if file_metadata.source_id != source_id:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.")
if should_use_pinecone() and not file_metadata.is_processing_terminal():
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor, limit=file_metadata.total_chunks)
if should_use_pinecone() and file_metadata.processing_status == FileProcessingStatus.EMBEDDING:
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor)
logger.info(
f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_id} ({file_metadata.file_name}) in organization {actor.organization_id}"
)
@@ -402,7 +402,7 @@ async def get_file_metadata(
file_status = file_metadata.processing_status
else:
file_status = FileProcessingStatus.COMPLETED
await server.file_manager.update_file_status(
file_metadata = await server.file_manager.update_file_status(
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
)

View File

@@ -1342,9 +1342,6 @@ class SyncServer(Server):
new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
assert new_passage_size >= curr_passage_size # in case empty files are added
# rebuild system prompt and force
agent_state = await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
# update job status
job.status = JobStatus.completed
job.metadata["num_passages"] = num_passages

View File

@@ -107,6 +107,32 @@ class AgentManager:
self.identity_manager = IdentityManager()
self.file_agent_manager = FileAgentManager()
@trace_method
async def _validate_agent_exists_async(self, session, agent_id: str, actor: PydanticUser) -> None:
"""
Validate that an agent exists and user has access to it using raw SQL for efficiency.
Args:
session: Database session
agent_id: ID of the agent to validate
actor: User performing the action
Raises:
NoResultFound: If agent doesn't exist or user doesn't have access
"""
agent_check_query = sa.text(
"""
SELECT 1 FROM agents
WHERE id = :agent_id
AND organization_id = :org_id
AND is_deleted = false
"""
)
agent_exists = await session.execute(agent_check_query, {"agent_id": agent_id, "org_id": actor.organization_id})
if not agent_exists.fetchone():
raise NoResultFound(f"Agent with ID {agent_id} not found")
@staticmethod
def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
"""
@@ -635,24 +661,24 @@ class AgentManager:
return init_messages
@trace_method
@enforce_types
@trace_method
def append_initial_message_sequence_to_in_context_messages(
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
) -> PydanticAgentState:
init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
@trace_method
@enforce_types
@trace_method
async def append_initial_message_sequence_to_in_context_messages_async(
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
) -> PydanticAgentState:
init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor)
@trace_method
@enforce_types
@trace_method
def update_agent(
self,
agent_id: str,
@@ -773,8 +799,8 @@ class AgentManager:
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def update_agent_async(
self,
agent_id: str,
@@ -1125,16 +1151,16 @@ class AgentManager:
async with db_registry.async_session() as session:
return await AgentModel.size_async(db_session=session, actor=actor)
@trace_method
@enforce_types
@trace_method
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def get_agent_by_id_async(
self,
agent_id: str,
@@ -1147,8 +1173,8 @@ class AgentManager:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
return await agent.to_pydantic_async(include_relationships=include_relationships)
@trace_method
@enforce_types
@trace_method
async def get_agents_by_ids_async(
self,
agent_ids: list[str],
@@ -1164,16 +1190,16 @@ class AgentManager:
)
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
@trace_method
@enforce_types
@trace_method
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
def delete_agent(self, agent_id: str, actor: PydanticUser) -> None:
"""
Deletes an agent and its associated relationships.
@@ -1220,8 +1246,8 @@ class AgentManager:
else:
logger.debug(f"Agent with ID {agent_id} successfully hard deleted")
@trace_method
@enforce_types
@trace_method
async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None:
"""
Deletes an agent and its associated relationships.
@@ -1270,8 +1296,8 @@ class AgentManager:
else:
logger.debug(f"Agent with ID {agent_id} successfully hard deleted")
@trace_method
@enforce_types
@trace_method
def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@@ -1279,8 +1305,8 @@ class AgentManager:
data = schema.dump(agent)
return AgentSchema(**data)
@trace_method
@enforce_types
@trace_method
def deserialize(
self,
serialized_agent: AgentSchema,
@@ -1349,8 +1375,8 @@ class AgentManager:
# ======================================================================================================================
# Per Agent Environment Variable Management
# ======================================================================================================================
@trace_method
@enforce_types
@trace_method
def _set_environment_variables(
self,
agent_id: str,
@@ -1405,8 +1431,8 @@ class AgentManager:
# Return the updated agent state
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@@ -1422,20 +1448,20 @@ class AgentManager:
# TODO: 2) These messages are ordered from oldest to newest
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
@trace_method
@enforce_types
@trace_method
def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
    """Return the agent's current in-context messages, in the order stored on the agent."""
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
    return self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)
@trace_method
@enforce_types
@trace_method
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
    """Return the agent's system message (by convention the first id in message_ids)."""
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
    system_message_id = agent_state.message_ids[0]
    return self.message_manager.get_message_by_id(message_id=system_message_id, actor=actor)
@trace_method
@enforce_types
@trace_method
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
    """Async variant: return the agent's system message (first id in message_ids)."""
    # No relationships are needed here — only message_ids — so pass an empty list.
    agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
    system_message_id = agent.message_ids[0]
    return await self.message_manager.get_message_by_id_async(message_id=system_message_id, actor=actor)
@@ -1443,8 +1469,8 @@ class AgentManager:
# TODO: This is duplicated below
# TODO: This is legacy code and should be cleaned up
# TODO: A lot of the memory "compilation" should be offset to a separate class
@trace_method
@enforce_types
@trace_method
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
"""Rebuilds the system message with the latest memory object and any shared memory block updates
@@ -1515,29 +1541,42 @@ class AgentManager:
else:
return agent_state
@trace_method
# TODO: This is probably one of the worst pieces of code I've ever written please rip up as you see wish
@enforce_types
@trace_method
async def rebuild_system_prompt_async(
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True, tool_rules_solver: Optional[ToolRulesSolver] = None
) -> PydanticAgentState:
self,
agent_id: str,
actor: PydanticUser,
force=False,
update_timestamp=True,
tool_rules_solver: Optional[ToolRulesSolver] = None,
dry_run: bool = False,
) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]:
"""Rebuilds the system message with the latest memory object and any shared memory block updates
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
"""
# Get the current agent state
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources"], actor=actor)
num_messages_task = self.message_manager.size_async(actor=actor, agent_id=agent_id)
num_archival_memories_task = self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
agent_state_task = self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
num_messages, num_archival_memories, agent_state = await asyncio.gather(
num_messages_task,
num_archival_memories_task,
agent_state_task,
)
if not tool_rules_solver:
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
curr_system_message = await self.get_system_message_async(
agent_id=agent_id, actor=actor
) # this is the system + memory bank, not just the system prompt
curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
if curr_system_message is None:
logger.warning(f"No system message found for agent {agent_state.id} and user {actor}")
return agent_state
return agent_state, curr_system_message, num_messages, num_archival_memories
curr_system_message_openai = curr_system_message.to_openai_dict()
@@ -1551,7 +1590,7 @@ class AgentManager:
logger.debug(
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
)
return agent_state
return agent_state, curr_system_message, num_messages, num_archival_memories
# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False
@@ -1561,9 +1600,6 @@ class AgentManager:
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
memory_edit_timestamp = curr_system_message.created_at
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
# update memory (TODO: potentially update recall/archival stats separately)
new_system_message_str = compile_system_message(
@@ -1582,63 +1618,67 @@ class AgentManager:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# Swap the system message out (only if there is a diff)
message = PydanticMessage.dict_to_message(
temp_message = PydanticMessage.dict_to_message(
agent_id=agent_id,
model=agent_state.llm_config.model,
openai_message_dict={"role": "system", "content": new_system_message_str},
)
message = await self.message_manager.update_message_by_id_async(
message_id=curr_system_message.id,
message_update=MessageUpdate(**message.model_dump()),
actor=actor,
)
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
else:
return agent_state
temp_message.id = curr_system_message.id
if not dry_run:
await self.message_manager.update_message_by_id_async(
message_id=curr_system_message.id,
message_update=MessageUpdate(**temp_message.model_dump()),
actor=actor,
)
else:
curr_system_message = temp_message
return agent_state, curr_system_message, num_messages, num_archival_memories
@trace_method
@enforce_types
@trace_method
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
    """Overwrite the agent's in-context message id list and return the updated agent state."""
    update = UpdateAgent(message_ids=message_ids)
    return self.update_agent(agent_id=agent_id, agent_update=update, actor=actor)
@trace_method
@enforce_types
@trace_method
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
    """Async variant: overwrite the agent's in-context message id list and return the updated state."""
    update = UpdateAgent(message_ids=message_ids)
    return await self.update_agent_async(agent_id=agent_id, agent_update=update, actor=actor)
@trace_method
@enforce_types
@trace_method
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Evict the oldest in-context messages up to index *num*, preserving the system message.

    Args:
        num: Slice start for the surviving tail; ids at indices [1, num) are dropped.
            Values <= 1 leave the context unchanged.
        agent_id: ID of the agent whose context is trimmed.
        actor: User performing the action.

    Returns:
        PydanticAgentState: The updated agent state.
    """
    message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
    # Index 0 is the system message. Clamp the slice start to 1: the previous
    # code used message_ids[num:], so num <= 0 duplicated the system message
    # ([ids[0]] + ids[0:]). For num >= 1 behavior is unchanged.
    new_messages = [message_ids[0]] + message_ids[max(num, 1) :]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
@trace_method
@enforce_types
@trace_method
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Drop every in-context message except the system message."""
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
    # TODO: How do we know this?  (index 0 is assumed to be the system message)
    system_only = [agent_state.message_ids[0]]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=system_only, actor=actor)
@trace_method
@enforce_types
@trace_method
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Persist *messages* and splice their ids into the context right after the system message."""
    current_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
    created = self.message_manager.create_many_messages(messages, actor=actor)
    # Keep the system message first, then the newly created ids, then the rest.
    new_order = [current_ids[0], *(m.id for m in created), *current_ids[1:]]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=new_order, actor=actor)
@trace_method
@enforce_types
@trace_method
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Persist *messages* and append their ids to the end of the agent's context window."""
    created = self.message_manager.create_many_messages(messages, actor=actor)
    existing_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
    combined = existing_ids + [m.id for m in created]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=combined, actor=actor)
@trace_method
@enforce_types
@trace_method
async def append_to_in_context_messages_async(
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
) -> PydanticAgentState:
@@ -1648,8 +1688,8 @@ class AgentManager:
message_ids += [m.id for m in messages]
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)
@trace_method
@enforce_types
@trace_method
async def reset_messages_async(
self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False
) -> PydanticAgentState:
@@ -1712,8 +1752,8 @@ class AgentManager:
else:
return agent_state
@trace_method
@enforce_types
@trace_method
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
"""
Update internal memory object and system prompt if there have been modifications.
@@ -1756,12 +1796,12 @@ class AgentManager:
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
# rebuild memory - this records the last edited timestamp of the memory
# TODO: pass in update timestamp from block edit time
agent_state = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor)
await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor)
return agent_state
@trace_method
@enforce_types
@trace_method
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
# TODO: This will NOT work for new blocks/file blocks added intra-step
block_ids = [b.id for b in agent_state.memory.blocks]
@@ -1779,8 +1819,8 @@ class AgentManager:
return agent_state
@trace_method
@enforce_types
@trace_method
async def refresh_file_blocks(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
file_blocks = await self.file_agent_manager.list_files_for_agent(agent_id=agent_state.id, actor=actor, return_as_blocks=True)
agent_state.memory.file_blocks = [b for b in file_blocks if b is not None]
@@ -1789,8 +1829,8 @@ class AgentManager:
# ======================================================================================================================
# Source Management
# ======================================================================================================================
@trace_method
@enforce_types
@trace_method
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a source to an agent.
@@ -1820,15 +1860,11 @@ class AgentManager:
)
# Commit the changes
await agent.update_async(session, actor=actor)
agent = await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
# Force rebuild of system prompt so that the agent is updated with passage count
pydantic_agent = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
return pydantic_agent
@trace_method
@enforce_types
@trace_method
def append_system_message(self, agent_id: str, content: str, actor: PydanticUser):
# get the agent
@@ -1840,8 +1876,8 @@ class AgentManager:
# update agent in-context message IDs
self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor)
@trace_method
@enforce_types
@trace_method
async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser):
# get the agent
@@ -1853,28 +1889,8 @@ class AgentManager:
# update agent in-context message IDs
await self.append_to_in_context_messages_async(messages=[message], agent_id=agent_id, actor=actor)
@trace_method
@enforce_types
def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
    """
    Lists all sources attached to an agent.

    Args:
        agent_id: ID of the agent to list sources for
        actor: User performing the action

    Returns:
        List[PydanticSource]: Pydantic copies of the sources attached to the agent
    """
    # Docstring fix: the previous version claimed "List[str]: List of source IDs",
    # but the method returns full PydanticSource objects, matching its annotation.
    with db_registry.session() as session:
        # Verify agent exists and user has permission to access it
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

        # Use the lazy-loaded relationship to get sources
        return [source.to_pydantic() for source in agent.sources]
@trace_method
@enforce_types
async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
"""
Lists all sources attached to an agent.
@@ -1885,44 +1901,34 @@ class AgentManager:
Returns:
List[str]: List of source IDs attached to the agent
Raises:
NoResultFound: If agent doesn't exist or user doesn't have access
"""
async with db_registry.async_session() as session:
# Verify agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Validate agent exists and user has access
await self._validate_agent_exists_async(session, agent_id, actor)
# Use the lazy-loaded relationship to get sources
return [source.to_pydantic() for source in agent.sources]
# Use raw SQL to efficiently fetch sources - much faster than lazy loading
# Fast query without relationship loading
query = (
select(SourceModel)
.join(SourcesAgents, SourceModel.id == SourcesAgents.source_id)
.where(
SourcesAgents.agent_id == agent_id,
SourceModel.organization_id == actor.organization_id,
SourceModel.is_deleted == False,
)
.order_by(SourceModel.created_at.desc(), SourceModel.id)
)
result = await session.execute(query)
sources = result.scalars().all()
return [source.to_pydantic() for source in sources]
@trace_method
@enforce_types
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
    """
    Detaches a source from an agent.

    Args:
        agent_id: ID of the agent to detach the source from
        source_id: ID of the source to detach
        actor: User performing the action
    """
    with db_registry.session() as session:
        # Load the agent, enforcing existence and access control for this actor.
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

        # Partition the relationship: keep everything except the target source.
        kept = []
        found = False
        for source in agent.sources:
            if source.id == source_id:
                found = True
            else:
                kept.append(source)

        if not found:
            # Detach is idempotent: warn, but still proceed to the update below.
            logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")

        # Update the sources relationship and persist.
        agent.sources = kept
        agent.update(session, actor=actor)
        return agent.to_pydantic()
@trace_method
@enforce_types
async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a source from an agent.
@@ -1931,29 +1937,36 @@ class AgentManager:
agent_id: ID of the agent to detach the source from
source_id: ID of the source to detach
actor: User performing the action
Raises:
NoResultFound: If agent doesn't exist or user doesn't have access
"""
async with db_registry.async_session() as session:
# Verify agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Validate agent exists and user has access
await self._validate_agent_exists_async(session, agent_id, actor)
# Remove the source from the relationship
remaining_sources = [s for s in agent.sources if s.id != source_id]
# Check if the source is actually attached to this agent using junction table
attachment_check_query = select(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
attachment_result = await session.execute(attachment_check_query)
attachment = attachment_result.scalar_one_or_none()
if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship
if not attachment:
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
else:
# Delete the association directly from the junction table
delete_query = delete(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
await session.execute(delete_query)
await session.commit()
# Update the sources relationship
agent.sources = remaining_sources
# Commit the changes
await agent.update_async(session, actor=actor)
# Get agent without loading relationships for return value
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
return await agent.to_pydantic_async()
# ======================================================================================================================
# Block management
# ======================================================================================================================
@trace_method
@enforce_types
@trace_method
def get_block_with_label(
self,
agent_id: str,
@@ -1968,8 +1981,8 @@ class AgentManager:
return block.to_pydantic()
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
@trace_method
@enforce_types
@trace_method
async def get_block_with_label_async(
self,
agent_id: str,
@@ -1984,8 +1997,8 @@ class AgentManager:
return block.to_pydantic()
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
@trace_method
@enforce_types
@trace_method
async def modify_block_by_label_async(
self,
agent_id: str,
@@ -2012,8 +2025,8 @@ class AgentManager:
await block.update_async(session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
def update_block_with_label(
self,
agent_id: str,
@@ -2037,8 +2050,8 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group."""
with db_registry.session() as session:
@@ -2067,8 +2080,8 @@ class AgentManager:
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group."""
async with db_registry.async_session() as session:
@@ -2103,8 +2116,8 @@ class AgentManager:
return await agent.to_pydantic_async()
@trace_method
@enforce_types
@trace_method
def detach_block(
self,
agent_id: str,
@@ -2124,8 +2137,8 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def detach_block_async(
self,
agent_id: str,
@@ -2145,8 +2158,8 @@ class AgentManager:
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
@trace_method
@enforce_types
@trace_method
def detach_block_with_label(
self,
agent_id: str,
@@ -2170,8 +2183,8 @@ class AgentManager:
# Passage Management
# ======================================================================================================================
@trace_method
@enforce_types
@trace_method
def list_passages(
self,
actor: PydanticUser,
@@ -2231,8 +2244,8 @@ class AgentManager:
return [p.to_pydantic() for p in passages]
@trace_method
@enforce_types
@trace_method
async def list_passages_async(
self,
actor: PydanticUser,
@@ -2292,8 +2305,8 @@ class AgentManager:
return [p.to_pydantic() for p in passages]
@trace_method
@enforce_types
@trace_method
async def list_source_passages_async(
self,
actor: PydanticUser,
@@ -2340,8 +2353,8 @@ class AgentManager:
# Convert to Pydantic models
return [p.to_pydantic() for p in passages]
@trace_method
@enforce_types
@trace_method
async def list_agent_passages_async(
self,
actor: PydanticUser,
@@ -2384,8 +2397,8 @@ class AgentManager:
# Convert to Pydantic models
return [p.to_pydantic() for p in passages]
@trace_method
@enforce_types
@trace_method
def passage_size(
self,
actor: PydanticUser,
@@ -2465,8 +2478,8 @@ class AgentManager:
# ======================================================================================================================
# Tool Management
# ======================================================================================================================
@trace_method
@enforce_types
@trace_method
def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a tool to an agent.
@@ -2501,8 +2514,8 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a tool to an agent.
@@ -2537,8 +2550,8 @@ class AgentManager:
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
@trace_method
@enforce_types
@trace_method
async def attach_missing_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches missing core file tools to an agent.
@@ -2569,8 +2582,8 @@ class AgentManager:
return agent_state
@trace_method
@enforce_types
@trace_method
async def detach_all_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
"""
Detach all core file tools from an agent.
@@ -2596,8 +2609,8 @@ class AgentManager:
return agent_state
@trace_method
@enforce_types
@trace_method
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a tool from an agent.
@@ -2630,8 +2643,8 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a tool from an agent.
@@ -2664,8 +2677,8 @@ class AgentManager:
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
@trace_method
@enforce_types
@trace_method
def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
"""
List all tools attached to an agent.
@@ -2681,11 +2694,40 @@ class AgentManager:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
return [tool.to_pydantic() for tool in agent.tools]
@enforce_types
@trace_method
async def list_attached_tools_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
    """
    List all tools attached to an agent (async version with optimized performance).

    Queries the tools/agents junction table directly instead of loading the ORM
    relationship, avoiding SqlAlchemyBase overhead.

    Args:
        agent_id: ID of the agent to list tools for.
        actor: User performing the action.

    Returns:
        List[PydanticTool]: List of tools attached to the agent.
    """
    async with db_registry.async_session() as session:
        # Lightweight existence/permission check before hitting the junction table.
        await self._validate_agent_exists_async(session, agent_id, actor)

        # Direct join through ToolsAgents — much more performant than lazy loading.
        stmt = select(ToolModel).join(ToolsAgents, ToolModel.id == ToolsAgents.tool_id)
        stmt = stmt.where(
            ToolsAgents.agent_id == agent_id,
            ToolModel.organization_id == actor.organization_id,
        )

        rows = (await session.execute(stmt)).scalars().all()
        return [row.to_pydantic() for row in rows]
# ======================================================================================================================
# Tag Management
# ======================================================================================================================
@trace_method
@enforce_types
@trace_method
def list_tags(
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
) -> List[str]:
@@ -2719,8 +2761,8 @@ class AgentManager:
results = [tag[0] for tag in query.all()]
return results
@trace_method
@enforce_types
@trace_method
async def list_tags_async(
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
) -> List[str]:
@@ -2759,8 +2801,11 @@ class AgentManager:
results = [row[0] for row in result.all()]
return results
@trace_method
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
agent_state = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
agent_id=agent_id, actor=actor, force=True, dry_run=True
)
calculator = ContextWindowCalculator()
if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION" and agent_state.llm_config.model_endpoint_type == "anthropic":
@@ -2776,5 +2821,7 @@ class AgentManager:
actor=actor,
token_counter=token_counter,
message_manager=self.message_manager,
passage_manager=self.passage_manager,
system_message_compiled=system_message,
num_archival_memories=num_archival_memories,
num_messages=num_messages,
)

View File

@@ -23,8 +23,8 @@ logger = get_logger(__name__)
class BlockManager:
"""Manager class to handle business logic related to Blocks."""
@trace_method
@enforce_types
@trace_method
def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
"""Create a new block based on the Block schema."""
db_block = self.get_block_by_id(block.id, actor)
@@ -38,8 +38,8 @@ class BlockManager:
block.create(session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
"""Create a new block based on the Block schema."""
db_block = await self.get_block_by_id_async(block.id, actor)
@@ -53,8 +53,8 @@ class BlockManager:
await block.create_async(session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
def batch_create_blocks(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
"""
Batch-create multiple Blocks in one transaction for better performance.
@@ -77,8 +77,8 @@ class BlockManager:
# Convert back to Pydantic
return [m.to_pydantic() for m in created_models]
@trace_method
@enforce_types
@trace_method
async def batch_create_blocks_async(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
"""
Batch-create multiple Blocks in one transaction for better performance.
@@ -101,8 +101,8 @@ class BlockManager:
# Convert back to Pydantic
return [m.to_pydantic() for m in created_models]
@trace_method
@enforce_types
@trace_method
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object."""
# Safety check for block
@@ -117,8 +117,8 @@ class BlockManager:
block.update(db_session=session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object."""
# Safety check for block
@@ -133,8 +133,8 @@ class BlockManager:
await block.update_async(db_session=session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
"""Delete a block by its ID."""
with db_registry.session() as session:
@@ -142,8 +142,8 @@ class BlockManager:
block.hard_delete(db_session=session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def delete_block_async(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
"""Delete a block by its ID."""
async with db_registry.async_session() as session:
@@ -151,8 +151,8 @@ class BlockManager:
await block.hard_delete_async(db_session=session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def get_blocks_async(
self,
actor: PydanticUser,
@@ -214,8 +214,8 @@ class BlockManager:
return [block.to_pydantic() for block in blocks]
@trace_method
@enforce_types
@trace_method
def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
"""Retrieve a block by its name."""
with db_registry.session() as session:
@@ -225,8 +225,8 @@ class BlockManager:
except NoResultFound:
return None
@trace_method
@enforce_types
@trace_method
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
"""Retrieve a block by its name."""
async with db_registry.async_session() as session:
@@ -236,8 +236,8 @@ class BlockManager:
except NoResultFound:
return None
@trace_method
@enforce_types
@trace_method
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
from sqlalchemy import select
@@ -284,8 +284,8 @@ class BlockManager:
return pydantic_blocks
@trace_method
@enforce_types
@trace_method
async def get_agents_for_block_async(
self,
block_id: str,
@@ -301,8 +301,8 @@ class BlockManager:
agents = await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents_orm])
return agents
@trace_method
@enforce_types
@trace_method
async def size_async(self, actor: PydanticUser) -> int:
"""
Get the total count of blocks for the given user.
@@ -312,8 +312,8 @@ class BlockManager:
# Block History Functions
@trace_method
@enforce_types
@trace_method
def checkpoint_block(
self,
block_id: str,
@@ -416,8 +416,8 @@ class BlockManager:
updated_block = block.update(db_session=session, actor=actor, no_commit=True)
return updated_block
@trace_method
@enforce_types
@trace_method
def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
"""
Move the block to the immediately previous checkpoint in BlockHistory.
@@ -459,8 +459,8 @@ class BlockManager:
session.commit()
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
"""
Move the block to the next checkpoint if it exists.
@@ -498,8 +498,8 @@ class BlockManager:
session.commit()
return block.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def bulk_update_block_values_async(
self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False
) -> Optional[List[PydanticBlock]]:

View File

@@ -4,11 +4,14 @@ from typing import Any, List, Optional, Tuple
from openai.types.beta.function_tool import FunctionTool as OpenAITool
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.memory import ContextWindowOverview
from letta.schemas.message import Message
from letta.schemas.user import User as PydanticUser
from letta.services.context_window_calculator.token_counter import TokenCounter
from letta.services.message_manager import MessageManager
logger = get_logger(__name__)
@@ -56,16 +59,18 @@ class ContextWindowCalculator:
return None, 1
async def calculate_context_window(
self, agent_state: Any, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
self,
agent_state: AgentState,
actor: PydanticUser,
token_counter: TokenCounter,
message_manager: MessageManager,
system_message_compiled: Message,
num_archival_memories: int,
num_messages: int,
) -> ContextWindowOverview:
"""Calculate context window information using the provided token counter"""
# Fetch data concurrently
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor),
passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id),
message_manager.size_async(actor=actor, agent_id=agent_state.id),
)
messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids[1:], actor=actor)
in_context_messages = [system_message_compiled] + messages
# Convert messages to appropriate format
converted_messages = token_counter.convert_messages(in_context_messages)
@@ -128,8 +133,8 @@ class ContextWindowCalculator:
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(in_context_messages),
num_archival_memory=passage_manager_size,
num_recall_memory=message_manager_size,
num_archival_memory=num_archival_memories,
num_recall_memory=num_messages,
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
external_memory_summary=external_memory_summary,
# top-level information

View File

@@ -1,7 +1,11 @@
import hashlib
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.anthropic_client import AnthropicClient
from letta.otel.tracing import trace_method
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.utils import count_tokens
@@ -33,16 +37,34 @@ class AnthropicTokenCounter(TokenCounter):
self.client = anthropic_client
self.model = model
@trace_method
@async_redis_cache(
    # Cache key: model name plus a short SHA-256 digest of the text, so equal
    # (model, text) pairs hit the cache without storing the raw text itself.
    key_func=lambda self, text: f"anthropic_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
    prefix="token_counter",
    ttl_s=3600,  # cache for 1 hour
)
async def count_text_tokens(self, text: str) -> int:
    """Count tokens in ``text`` using the Anthropic count-tokens API.

    The text is wrapped as a single user message because the client's
    ``count_tokens`` takes a message list, not a raw string. Empty input
    short-circuits to 0 without a network call (and without caching).
    """
    if not text:
        return 0
    return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "content": text}])
@trace_method
@async_redis_cache(
    # Key on a digest of the canonical (sort_keys=True) JSON form of the
    # messages, so logically-equal message lists share one cache entry.
    key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
    prefix="token_counter",
    ttl_s=3600,  # cache for 1 hour
)
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
    """Count tokens for a list of chat messages via the Anthropic API.

    Returns 0 for an empty list without making a network call.
    NOTE(review): assumes ``messages`` is JSON-serializable (required by the
    cache key function) — confirm callers never pass non-JSON values.
    """
    if not messages:
        return 0
    return await self.client.count_tokens(model=self.model, messages=messages)
@trace_method
@async_redis_cache(
key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
prefix="token_counter",
ttl_s=3600, # cache for 1 hour
)
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
if not tools:
return 0
@@ -58,11 +80,23 @@ class TiktokenCounter(TokenCounter):
def __init__(self, model: str):
self.model = model
@trace_method
@async_redis_cache(
    # Same keying scheme as the Anthropic counter: model + short text digest.
    key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
    prefix="token_counter",
    ttl_s=3600,  # cache for 1 hour
)
async def count_text_tokens(self, text: str) -> int:
    """Count tokens in ``text`` locally with tiktoken (no API call).

    Declared ``async`` only to satisfy the shared TokenCounter interface;
    the underlying ``count_tokens`` helper is synchronous. Empty input
    returns 0 immediately.
    """
    if not text:
        return 0
    return count_tokens(text)
@trace_method
@async_redis_cache(
key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
prefix="token_counter",
ttl_s=3600, # cache for 1 hour
)
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
if not messages:
return 0
@@ -70,6 +104,12 @@ class TiktokenCounter(TokenCounter):
return num_tokens_from_messages(messages=messages, model=self.model)
@trace_method
@async_redis_cache(
key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
prefix="token_counter",
ttl_s=3600, # cache for 1 hour
)
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
if not tools:
return 0

View File

@@ -40,6 +40,10 @@ class LineChunker:
def _chunk_by_lines(self, text: str, preserve_indentation: bool = False) -> List[str]:
"""Traditional line-based chunking for code and structured data"""
# early stop; can happen if there's nothing in a specific file
if not text:
return []
lines = []
for line in text.splitlines():
if preserve_indentation:
@@ -57,6 +61,10 @@ class LineChunker:
def _chunk_by_sentences(self, text: str) -> List[str]:
"""Sentence-based chunking for documentation and markup"""
# early stop; can happen if there's nothing in a specific file
if not text:
return []
# Simple sentence splitting on periods, exclamation marks, and question marks
# followed by whitespace or end of string
sentence_pattern = r"(?<=[.!?])\s+(?=[A-Z])"
@@ -75,6 +83,10 @@ class LineChunker:
def _chunk_by_characters(self, text: str, target_line_length: int = 100) -> List[str]:
"""Character-based wrapping for prose text"""
# early stop; can happen if there's nothing in a specific file
if not text:
return []
words = text.split()
lines = []
current_line = []
@@ -110,6 +122,11 @@ class LineChunker:
strategy = self._determine_chunking_strategy(file_metadata)
text = file_metadata.content
# early stop; can happen if there's nothing in a specific file
if not text:
logger.warning(f"File ({file_metadata}) has no content")
return []
# Apply the appropriate chunking strategy
if strategy == ChunkingStrategy.DOCUMENTATION:
content_lines = self._chunk_by_sentences(text)

View File

@@ -25,7 +25,6 @@ class OpenAIEmbedder(BaseEmbedder):
else EmbeddingConfig.default_config(model_name="letta")
)
self.embedding_config = embedding_config or self.default_embedding_config
self.max_concurrent_requests = 20
# TODO: Unify to global OpenAI client
self.client: OpenAIClient = cast(
@@ -48,9 +47,55 @@ class OpenAIEmbedder(BaseEmbedder):
"embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
},
)
embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
return [(idx, e) for idx, e in zip(batch_indices, embeddings)]
try:
embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
return [(idx, e) for idx, e in zip(batch_indices, embeddings)]
except Exception as e:
# if it's a token limit error and we can split, do it
if self._is_token_limit_error(e) and len(batch) > 1:
logger.warning(f"Token limit exceeded for batch of size {len(batch)}, splitting in half and retrying")
log_event(
"embedder.batch_split_retry",
{
"original_batch_size": len(batch),
"error": str(e),
"split_size": len(batch) // 2,
},
)
# split batch in half
mid = len(batch) // 2
batch1 = batch[:mid]
batch1_indices = batch_indices[:mid]
batch2 = batch[mid:]
batch2_indices = batch_indices[mid:]
# retry with smaller batches
result1 = await self._embed_batch(batch1, batch1_indices)
result2 = await self._embed_batch(batch2, batch2_indices)
return result1 + result2
else:
# re-raise for other errors or if batch size is already 1
raise
def _is_token_limit_error(self, error: Exception) -> bool:
"""Check if the error is due to token limit exceeded"""
# convert to string and check for token limit patterns
error_str = str(error).lower()
# TODO: This is quite brittle, works for now
# check for the specific patterns we see in token limit errors
is_token_limit = (
"max_tokens_per_request" in error_str
or ("requested" in error_str and "tokens" in error_str and "max" in error_str and "per request" in error_str)
or "token limit" in error_str
or ("bad request to openai" in error_str and "tokens" in error_str and "max" in error_str)
)
return is_token_limit
@trace_method
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
@@ -100,7 +145,7 @@ class OpenAIEmbedder(BaseEmbedder):
log_event(
"embedder.concurrent_processing_started",
{"concurrent_tasks": len(tasks), "max_concurrent_requests": self.max_concurrent_requests},
{"concurrent_tasks": len(tasks)},
)
results = await asyncio.gather(*tasks)
log_event("embedder.concurrent_processing_completed", {"batches_processed": len(results)})

View File

@@ -29,6 +29,7 @@ class FileAgentManager:
agent_id: str,
file_id: str,
file_name: str,
source_id: str,
actor: PydanticUser,
is_open: bool = True,
visible_content: Optional[str] = None,
@@ -47,7 +48,12 @@ class FileAgentManager:
if is_open:
# Use the efficient LRU + open method
closed_files, was_already_open = await self.enforce_max_open_files_and_open(
agent_id=agent_id, file_id=file_id, file_name=file_name, actor=actor, visible_content=visible_content or ""
agent_id=agent_id,
file_id=file_id,
file_name=file_name,
source_id=source_id,
actor=actor,
visible_content=visible_content or "",
)
# Get the updated file agent to return
@@ -85,6 +91,7 @@ class FileAgentManager:
agent_id=agent_id,
file_id=file_id,
file_name=file_name,
source_id=source_id,
organization_id=actor.organization_id,
is_open=is_open,
visible_content=visible_content,
@@ -327,7 +334,7 @@ class FileAgentManager:
@enforce_types
@trace_method
async def enforce_max_open_files_and_open(
self, *, agent_id: str, file_id: str, file_name: str, actor: PydanticUser, visible_content: str
self, *, agent_id: str, file_id: str, file_name: str, source_id: str, actor: PydanticUser, visible_content: str
) -> tuple[List[str], bool]:
"""
Efficiently handle LRU eviction and file opening in a single transaction.
@@ -336,6 +343,7 @@ class FileAgentManager:
agent_id: ID of the agent
file_id: ID of the file to open
file_name: Name of the file to open
source_id: ID of the source (denormalized from files.source_id)
actor: User performing the action
visible_content: Content to set for the opened file
@@ -418,6 +426,7 @@ class FileAgentManager:
agent_id=agent_id,
file_id=file_id,
file_name=file_name,
source_id=source_id,
organization_id=actor.organization_id,
is_open=True,
visible_content=visible_content,
@@ -516,6 +525,7 @@ class FileAgentManager:
agent_id=agent_id,
file_id=meta.id,
file_name=meta.file_name,
source_id=meta.source_id,
organization_id=actor.organization_id,
is_open=is_now_open,
visible_content=vc,

View File

@@ -19,8 +19,8 @@ from letta.utils import enforce_types
class GroupManager:
@trace_method
@enforce_types
@trace_method
def list_groups(
self,
actor: PydanticUser,
@@ -45,22 +45,22 @@ class GroupManager:
)
return [group.to_pydantic() for group in groups]
@trace_method
@enforce_types
@trace_method
def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
    """Fetch a single group by id and return it as a Pydantic schema.

    Access control is delegated to ``GroupModel.read`` (which receives
    ``actor``); presumably it raises when the group is missing or not
    visible to the actor — confirm against the ORM base class.
    """
    with db_registry.session() as session:
        group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
        return group.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def retrieve_group_async(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
    """Async counterpart of ``retrieve_group``: fetch one group by id.

    Uses the async session and ``GroupModel.read_async``; otherwise
    identical semantics to the sync variant.
    """
    async with db_registry.async_session() as session:
        group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
        return group.to_pydantic()
@trace_method
@enforce_types
@trace_method
def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup:
with db_registry.session() as session:
new_group = GroupModel()
@@ -150,8 +150,8 @@ class GroupManager:
await new_group.create_async(session, actor=actor)
return new_group.to_pydantic()
@trace_method
@enforce_types
@trace_method
async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
async with db_registry.async_session() as session:
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
@@ -213,16 +213,16 @@ class GroupManager:
await group.update_async(session, actor=actor)
return group.to_pydantic()
@trace_method
@enforce_types
@trace_method
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
    """Permanently delete a group.

    The group is read first (scoped to ``actor``) so access checks apply
    before the delete; ``hard_delete`` removes the row outright rather
    than setting a soft-delete flag.
    """
    with db_registry.session() as session:
        # Retrieve the group (not an agent, despite the surrounding code style)
        group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
        group.hard_delete(session)
@trace_method
@enforce_types
@trace_method
def list_group_messages(
self,
actor: PydanticUser,
@@ -258,8 +258,8 @@ class GroupManager:
return messages
@trace_method
@enforce_types
@trace_method
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
with db_registry.session() as session:
# Ensure group is loadable by user
@@ -272,8 +272,8 @@ class GroupManager:
session.commit()
@trace_method
@enforce_types
@trace_method
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
with db_registry.session() as session:
# Ensure group is loadable by user
@@ -284,8 +284,8 @@ class GroupManager:
group.update(session, actor=actor)
return group.turns_counter
@trace_method
@enforce_types
@trace_method
async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int:
async with db_registry.async_session() as session:
# Ensure group is loadable by user
@@ -309,8 +309,8 @@ class GroupManager:
return prev_last_processed_message_id
@trace_method
@enforce_types
@trace_method
async def get_last_processed_message_id_and_update_async(
self, group_id: str, last_processed_message_id: str, actor: PydanticUser
) -> str:

View File

@@ -1,8 +1,12 @@
import asyncio
from typing import List, Optional
from sqlalchemy import select
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.source import Source as SourceModel
from letta.orm.sources_agents import SourcesAgents
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.source import Source as PydanticSource
@@ -104,9 +108,21 @@ class SourceManager:
# Verify source exists and user has permission to access it
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
# The agents relationship is already loaded due to lazy="selectin" in the Source model
# and will be properly filtered by organization_id due to the OrganizationMixin
agents_orm = source.agents
# Use junction table query instead of relationship to avoid performance issues
query = (
select(AgentModel)
.join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
.where(
SourcesAgents.source_id == source_id,
AgentModel.organization_id == actor.organization_id if actor else True,
AgentModel.is_deleted == False,
)
.order_by(AgentModel.created_at.desc(), AgentModel.id)
)
result = await session.execute(query)
agents_orm = result.scalars().all()
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons

View File

@@ -188,7 +188,7 @@ class LettaCoreToolExecutor(ToolExecutor):
Append to the contents of core memory.
Args:
label (str): Section of the memory to be edited (persona or human).
label (str): Section of the memory to be edited.
content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
@@ -214,7 +214,7 @@ class LettaCoreToolExecutor(ToolExecutor):
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Args:
label (str): Section of the memory to be edited (persona or human).
label (str): Section of the memory to be edited.
old_content (str): String to replace. Must be an exact match.
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.

View File

@@ -180,7 +180,12 @@ class LettaFileToolExecutor(ToolExecutor):
# Handle LRU eviction and file opening
closed_files, was_already_open = await self.files_agents_manager.enforce_max_open_files_and_open(
agent_id=agent_state.id, file_id=file_id, file_name=file_name, actor=self.actor, visible_content=visible_content
agent_id=agent_state.id,
file_id=file_id,
file_name=file_name,
source_id=file.source_id,
actor=self.actor,
visible_content=visible_content,
)
opened_files.append(file_name)

6
poetry.lock generated
View File

@@ -3591,14 +3591,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"]
[[package]]
name = "letta-client"
version = "0.1.198"
version = "0.1.197"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
groups = ["main"]
files = [
{file = "letta_client-0.1.198-py3-none-any.whl", hash = "sha256:08bbc238b128da2552b2a6e54feb3294794b5586e0962ce0bb95bb525109f58f"},
{file = "letta_client-0.1.198.tar.gz", hash = "sha256:990c9132423e2955d9c7f7549e5064b2366616232d270e5927788cddba4ef9da"},
{file = "letta_client-0.1.197-py3-none-any.whl", hash = "sha256:b01eab01ff87a34e79622cd8d3a2f3da56b6bd730312a268d968576671f6cce0"},
{file = "letta_client-0.1.197.tar.gz", hash = "sha256:579571623ccec81422087cb7957d5f580fa0b7a53f8495596e5c56d0213220d8"},
]
[package.dependencies]

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.8.13"
version = "0.8.14"
packages = [
{include = "letta"},
]
@@ -116,7 +116,6 @@ google = ["google-genai"]
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone"]
[tool.poetry.group.dev.dependencies]
black = "^24.4.2"
ipykernel = "^6.29.5"

View File

@@ -306,7 +306,7 @@
"properties": {
"label": {
"type": "string",
"description": "Section of the memory to be edited (persona or human)."
"description": "Section of the memory to be edited."
},
"content": {
"type": "string",
@@ -343,7 +343,7 @@
"properties": {
"label": {
"type": "string",
"description": "Section of the memory to be edited (persona or human)."
"description": "Section of the memory to be edited."
},
"old_content": {
"type": "string",

View File

@@ -237,7 +237,7 @@ def validate_context_window_overview(
# 16. Check attached file is visible
if attached_file:
assert attached_file.visible_content in overview.core_memory
assert attached_file.visible_content in overview.core_memory, "File must be attached in core memory"
assert '<file status="open"' in overview.core_memory
assert "</file>" in overview.core_memory

View File

@@ -0,0 +1,219 @@
from unittest.mock import AsyncMock, Mock, patch
import openai
import pytest
from letta.errors import ErrorCode, LLMBadRequestError
from letta.schemas.embedding_config import EmbeddingConfig
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
class TestOpenAIEmbedder:
"""Test suite for OpenAI embedder functionality"""
@pytest.fixture
def mock_user(self):
"""Create a mock user for testing"""
user = Mock()
user.organization_id = "test_org_id"
return user
@pytest.fixture
def embedding_config(self):
"""Create a test embedding config"""
return EmbeddingConfig(
embedding_model="text-embedding-3-small",
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=3, # small dimension for testing
embedding_chunk_size=300,
batch_size=2, # small batch size for testing
)
@pytest.fixture
def embedder(self, embedding_config):
"""Create OpenAI embedder with test config"""
with patch("letta.services.file_processor.embedder.openai_embedder.LLMClient.create") as mock_create:
mock_client = Mock()
mock_client.handle_llm_error = Mock()
mock_create.return_value = mock_client
embedder = OpenAIEmbedder(embedding_config)
embedder.client = mock_client
return embedder
@pytest.mark.asyncio
async def test_successful_embedding_generation(self, embedder, mock_user):
"""Test successful embedding generation for normal cases"""
# mock successful embedding response
mock_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
embedder.client.request_embeddings = AsyncMock(return_value=mock_embeddings)
chunks = ["chunk 1", "chunk 2"]
file_id = "test_file"
source_id = "test_source"
passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
assert len(passages) == 2
assert passages[0].text == "chunk 1"
assert passages[1].text == "chunk 2"
# embeddings are padded to MAX_EMBEDDING_DIM, so check first 3 values
assert passages[0].embedding[:3] == [0.1, 0.2, 0.3]
assert passages[1].embedding[:3] == [0.4, 0.5, 0.6]
assert passages[0].file_id == file_id
assert passages[0].source_id == source_id
@pytest.mark.asyncio
async def test_token_limit_retry_splits_batch(self, embedder, mock_user):
"""Test that token limit errors trigger batch splitting and retry"""
# create a mock token limit error
mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
# first call fails with token limit, subsequent calls succeed
call_count = 0
async def mock_request_embeddings(inputs, embedding_config):
nonlocal call_count
call_count += 1
if call_count == 1 and len(inputs) == 4: # first call with full batch
raise token_limit_error
elif len(inputs) == 2: # split batches succeed
return [[0.1, 0.2], [0.3, 0.4]] if call_count == 2 else [[0.5, 0.6], [0.7, 0.8]]
else:
return [[0.1, 0.2]] * len(inputs)
embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings)
chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"]
file_id = "test_file"
source_id = "test_source"
passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
# should still get all 4 passages despite the retry
assert len(passages) == 4
assert all(len(p.embedding) == 4096 for p in passages) # padded to MAX_EMBEDDING_DIM
# verify multiple calls were made (original + retries)
assert call_count >= 2
@pytest.mark.asyncio
async def test_token_limit_error_detection(self, embedder):
    """Test various token limit error detection patterns.

    ``_is_token_limit_error`` matches on the lowercased string form of the
    exception, so both structured OpenAI errors and plain ``Exception``
    messages should be recognized.
    """
    # test openai BadRequestError with proper structure
    mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
    openai_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
    assert embedder._is_token_limit_error(openai_error) is True

    # test error with message but no code — detection should still fire on the message text
    mock_error_body_no_code = {"error": {"message": "max_tokens_per_request exceeded"}}
    openai_error_no_code = openai.BadRequestError(
        message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body_no_code
    )
    assert embedder._is_token_limit_error(openai_error_no_code) is True

    # test fallback string detection (plain Exception, no OpenAI structure)
    generic_error = Exception("Requested 100000 tokens, max 50000 tokens per request")
    assert embedder._is_token_limit_error(generic_error) is True

    # test non-token errors — must not be misclassified as token-limit failures
    other_error = Exception("Some other error")
    assert embedder._is_token_limit_error(other_error) is False

    auth_error = openai.AuthenticationError(
        message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}}
    )
    assert embedder._is_token_limit_error(auth_error) is False
@pytest.mark.asyncio
async def test_non_token_error_handling(self, embedder, mock_user):
"""Test that non-token errors are properly handled and re-raised"""
# create a non-token error
auth_error = openai.AuthenticationError(
message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}}
)
# mock handle_llm_error to return a standardized error
handled_error = LLMBadRequestError(message="Handled error", code=ErrorCode.UNAUTHENTICATED)
embedder.client.handle_llm_error.return_value = handled_error
embedder.client.request_embeddings = AsyncMock(side_effect=auth_error)
chunks = ["chunk 1"]
file_id = "test_file"
source_id = "test_source"
with pytest.raises(LLMBadRequestError) as exc_info:
await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
assert exc_info.value == handled_error
embedder.client.handle_llm_error.assert_called_once_with(auth_error)
@pytest.mark.asyncio
async def test_single_item_batch_no_retry(self, embedder, mock_user):
"""Test that single-item batches don't retry on token limit errors"""
# create a token limit error
mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
handled_error = LLMBadRequestError(message="Handled token limit error", code=ErrorCode.INVALID_ARGUMENT)
embedder.client.handle_llm_error.return_value = handled_error
embedder.client.request_embeddings = AsyncMock(side_effect=token_limit_error)
chunks = ["very long chunk that exceeds token limit"]
file_id = "test_file"
source_id = "test_source"
with pytest.raises(LLMBadRequestError) as exc_info:
await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
assert exc_info.value == handled_error
embedder.client.handle_llm_error.assert_called_once_with(token_limit_error)
@pytest.mark.asyncio
async def test_empty_chunks_handling(self, embedder, mock_user):
"""Test handling of empty chunks list"""
chunks = []
file_id = "test_file"
source_id = "test_source"
passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
assert passages == []
# should not call request_embeddings for empty input
embedder.client.request_embeddings.assert_not_called()
@pytest.mark.asyncio
async def test_embedding_order_preservation(self, embedder, mock_user):
"""Test that embedding order is preserved even with retries"""
# set up embedder to split batches (batch_size=2)
embedder.embedding_config.batch_size = 2
# mock responses for each batch
async def mock_request_embeddings(inputs, embedding_config):
# return embeddings that correspond to input order
if inputs == ["chunk 1", "chunk 2"]:
return [[0.1, 0.1], [0.2, 0.2]]
elif inputs == ["chunk 3", "chunk 4"]:
return [[0.3, 0.3], [0.4, 0.4]]
else:
return [[0.1, 0.1]] * len(inputs)
embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings)
chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"]
file_id = "test_file"
source_id = "test_source"
passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
# verify order is preserved
assert len(passages) == 4
assert passages[0].text == "chunk 1"
assert passages[0].embedding[:2] == [0.1, 0.1] # check first 2 values before padding
assert passages[1].text == "chunk 2"
assert passages[1].embedding[:2] == [0.2, 0.2]
assert passages[2].text == "chunk 3"
assert passages[2].embedding[:2] == [0.3, 0.3]
assert passages[3].text == "chunk 4"
assert passages[3].embedding[:2] == [0.4, 0.4]

View File

@@ -673,6 +673,7 @@ async def file_attachment(server, default_user, sarah_agent, default_file):
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
visible_content="initial",
)
@@ -903,6 +904,7 @@ async def test_get_context_window_basic(
agent_id=created_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
visible_content="hello",
)
@@ -7221,6 +7223,7 @@ async def test_attach_creates_association(server, default_user, sarah_agent, def
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
visible_content="hello",
)
@@ -7243,6 +7246,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
visible_content="first",
)
@@ -7252,6 +7256,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
is_open=False,
visible_content="second",
@@ -7326,15 +7331,28 @@ async def test_list_files_and_agents(
):
# default_file ↔ charles (open)
await server.file_agent_manager.attach_file(
agent_id=charles_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user
agent_id=charles_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
)
# default_file ↔ sarah (open)
await server.file_agent_manager.attach_file(
agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
)
# another_file ↔ sarah (closed)
await server.file_agent_manager.attach_file(
agent_id=sarah_agent.id, file_id=another_file.id, file_name=another_file.file_name, actor=default_user, is_open=False
agent_id=sarah_agent.id,
file_id=another_file.id,
file_name=another_file.file_name,
source_id=another_file.source_id,
actor=default_user,
is_open=False,
)
files_for_sarah = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user)
@@ -7384,6 +7402,7 @@ async def test_org_scoping(
agent_id=sarah_agent.id,
file_id=default_file.id,
file_name=default_file.file_name,
source_id=default_file.source_id,
actor=default_user,
)
@@ -7420,6 +7439,7 @@ async def test_mark_access_bulk(server, default_user, sarah_agent, default_sourc
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content=f"content for {file.file_name}",
)
@@ -7478,6 +7498,7 @@ async def test_lru_eviction_on_attach(server, default_user, sarah_agent, default
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content=f"content for {file.file_name}",
)
@@ -7530,6 +7551,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
agent_id=sarah_agent.id,
file_id=files[i].id,
file_name=files[i].file_name,
source_id=files[i].source_id,
actor=default_user,
visible_content=f"content for {files[i].file_name}",
)
@@ -7539,6 +7561,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
agent_id=sarah_agent.id,
file_id=files[-1].id,
file_name=files[-1].file_name,
source_id=files[-1].source_id,
actor=default_user,
is_open=False,
visible_content=f"content for {files[-1].file_name}",
@@ -7555,7 +7578,12 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
# Now "open" the last file using the efficient method
closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open(
agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content"
agent_id=sarah_agent.id,
file_id=files[-1].id,
file_name=files[-1].file_name,
source_id=files[-1].source_id,
actor=default_user,
visible_content="updated content",
)
# Should have closed 1 file (the oldest one)
@@ -7603,6 +7631,7 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content=f"content for {file.file_name}",
)
@@ -7617,7 +7646,12 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa
# "Reopen" the last file (which is already open)
closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open(
agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content"
agent_id=sarah_agent.id,
file_id=files[-1].id,
file_name=files[-1].file_name,
source_id=files[-1].source_id,
actor=default_user,
visible_content="updated content",
)
# Should not have closed any files since we're within the limit
@@ -7645,7 +7679,12 @@ async def test_last_accessed_at_updates_correctly(server, default_user, sarah_ag
file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user, text="test content")
file_agent, closed_files = await server.file_agent_manager.attach_file(
agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, actor=default_user, visible_content="initial content"
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content="initial content",
)
initial_time = file_agent.last_accessed_at
@@ -7777,6 +7816,7 @@ async def test_attach_files_bulk_lru_eviction(server, default_user, sarah_agent,
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
visible_content=f"existing content {i}",
)
@@ -7842,6 +7882,7 @@ async def test_attach_files_bulk_mixed_existing_new(server, default_user, sarah_
agent_id=sarah_agent.id,
file_id=existing_file.id,
file_name=existing_file.file_name,
source_id=existing_file.source_id,
actor=default_user,
visible_content="old content",
is_open=False, # Start as closed

View File

@@ -1033,3 +1033,36 @@ def test_preview_payload(client: LettaSDKClient, agent):
assert payload["user"].startswith("user-")
print(payload)
def test_agent_tools_list(client: LettaSDKClient):
"""Test the optimized agent tools list endpoint for correctness."""
# Create a test agent
agent_state = client.agents.create(
name="test_agent_tools_list",
memory_blocks=[
CreateBlock(
label="persona",
value="You are a helpful assistant.",
),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
include_base_tools=True,
)
try:
# Test basic functionality
tools = client.agents.tools.list(agent_id=agent_state.id)
assert len(tools) > 0, "Agent should have base tools attached"
# Verify tool objects have expected attributes
for tool in tools:
assert hasattr(tool, "id"), "Tool should have id attribute"
assert hasattr(tool, "name"), "Tool should have name attribute"
assert tool.id is not None, "Tool id should not be None"
assert tool.name is not None, "Tool name should not be None"
finally:
# Clean up
client.agents.delete(agent_id=agent_state.id)

View File

@@ -464,9 +464,8 @@ def test_line_chunker_edge_case_empty_file():
file = FileMetadata(file_name="empty.py", source_id="test_source", content="")
chunker = LineChunker()
# Test requesting lines from empty file
with pytest.raises(ValueError, match="File empty.py has only 0 lines, but requested offset 1 is out of range"):
chunker.chunk_text(file, start=0, end=1, validate_range=True)
# no error
chunker.chunk_text(file, start=0, end=1, validate_range=True)
def test_line_chunker_edge_case_single_line():