chore: bump version 0.8.14 (#2720)
Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: Matthew Zhou <mattzh1314@gmail.com> Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com> Co-authored-by: jnjpng <jin@letta.com> Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local> Co-authored-by: cpacker <packercharles@gmail.com> Co-authored-by: Shubham Naik <shub@letta.com> Co-authored-by: Shubham Naik <shub@memgpt.ai> Co-authored-by: Kevin Lin <klin5061@gmail.com>
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
"""Write source_id directly to files agents
|
||||
|
||||
Revision ID: 495f3f474131
|
||||
Revises: 47d2277e530d
|
||||
Create Date: 2025-07-10 17:14:45.154738
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "495f3f474131"
|
||||
down_revision: Union[str, None] = "47d2277e530d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Step 1: Add the column as nullable first
|
||||
op.add_column("files_agents", sa.Column("source_id", sa.String(), nullable=True))
|
||||
|
||||
# Step 2: Backfill source_id from files table
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE files_agents
|
||||
SET source_id = files.source_id
|
||||
FROM files
|
||||
WHERE files_agents.file_id = files.id
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 3: Make the column NOT NULL now that it's populated
|
||||
op.alter_column("files_agents", "source_id", nullable=False)
|
||||
|
||||
# Step 4: Add the foreign key constraint
|
||||
op.create_foreign_key(None, "files_agents", "sources", ["source_id"], ["id"], ondelete="CASCADE")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, "files_agents", type_="foreignkey")
|
||||
op.drop_column("files_agents", "source_id")
|
||||
# ### end Alembic commands ###
|
||||
@@ -5,7 +5,7 @@ try:
|
||||
__version__ = version("letta")
|
||||
except PackageNotFoundError:
|
||||
# Fallback for development installations
|
||||
__version__ = "0.8.13"
|
||||
__version__ = "0.8.14"
|
||||
|
||||
if os.environ.get("LETTA_VERSION"):
|
||||
__version__ = os.environ["LETTA_VERSION"]
|
||||
|
||||
@@ -372,3 +372,9 @@ PINECONE_METRIC = "cosine"
|
||||
PINECONE_CLOUD = "aws"
|
||||
PINECONE_REGION = "us-east-1"
|
||||
PINECONE_MAX_BATCH_SIZE = 96
|
||||
|
||||
# retry configuration
|
||||
PINECONE_MAX_RETRY_ATTEMPTS = 3
|
||||
PINECONE_RETRY_BASE_DELAY = 1.0 # seconds
|
||||
PINECONE_RETRY_MAX_DELAY = 60.0 # seconds
|
||||
PINECONE_RETRY_BACKOFF_FACTOR = 2.0
|
||||
|
||||
@@ -135,7 +135,7 @@ def core_memory_append(agent_state: "AgentState", label: str, content: str) -> O
|
||||
Append to the contents of core memory.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited (persona or human).
|
||||
label (str): Section of the memory to be edited.
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
@@ -152,7 +152,7 @@ def core_memory_replace(agent_state: "AgentState", label: str, old_content: str,
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited (persona or human).
|
||||
label (str): Section of the memory to be edited.
|
||||
old_content (str): String to replace. Must be an exact match.
|
||||
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
|
||||
@@ -1,8 +1,20 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from letta.otel.tracing import trace_method
|
||||
|
||||
try:
|
||||
from pinecone import IndexEmbed, PineconeAsyncio
|
||||
from pinecone.exceptions.exceptions import NotFoundException
|
||||
from pinecone.exceptions.exceptions import (
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
PineconeApiException,
|
||||
ServiceException,
|
||||
UnauthorizedException,
|
||||
)
|
||||
|
||||
PINECONE_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -12,8 +24,12 @@ from letta.constants import (
|
||||
PINECONE_CLOUD,
|
||||
PINECONE_EMBEDDING_MODEL,
|
||||
PINECONE_MAX_BATCH_SIZE,
|
||||
PINECONE_MAX_RETRY_ATTEMPTS,
|
||||
PINECONE_METRIC,
|
||||
PINECONE_REGION,
|
||||
PINECONE_RETRY_BACKOFF_FACTOR,
|
||||
PINECONE_RETRY_BASE_DELAY,
|
||||
PINECONE_RETRY_MAX_DELAY,
|
||||
PINECONE_TEXT_FIELD_NAME,
|
||||
)
|
||||
from letta.log import get_logger
|
||||
@@ -23,6 +39,87 @@ from letta.settings import settings
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def pinecone_retry(
|
||||
max_attempts: int = PINECONE_MAX_RETRY_ATTEMPTS,
|
||||
base_delay: float = PINECONE_RETRY_BASE_DELAY,
|
||||
max_delay: float = PINECONE_RETRY_MAX_DELAY,
|
||||
backoff_factor: float = PINECONE_RETRY_BACKOFF_FACTOR,
|
||||
):
|
||||
"""
|
||||
Decorator to retry Pinecone operations with exponential backoff.
|
||||
|
||||
Args:
|
||||
max_attempts: Maximum number of retry attempts
|
||||
base_delay: Base delay in seconds for the first retry
|
||||
max_delay: Maximum delay in seconds between retries
|
||||
backoff_factor: Factor to increase delay after each failed attempt
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
operation_name = func.__name__
|
||||
start_time = time.time()
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
logger.debug(f"[Pinecone] Starting {operation_name} (attempt {attempt + 1}/{max_attempts})")
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"[Pinecone] {operation_name} completed successfully in {execution_time:.2f}s")
|
||||
return result
|
||||
|
||||
except (ServiceException, PineconeApiException) as e:
|
||||
# retryable server errors
|
||||
if attempt == max_attempts - 1:
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"[Pinecone] {operation_name} failed after {max_attempts} attempts in {execution_time:.2f}s: {str(e)}")
|
||||
raise
|
||||
|
||||
# calculate delay with exponential backoff and jitter
|
||||
delay = min(base_delay * (backoff_factor**attempt), max_delay)
|
||||
jitter = random.uniform(0, delay * 0.1) # add up to 10% jitter
|
||||
total_delay = delay + jitter
|
||||
|
||||
logger.warning(
|
||||
f"[Pinecone] {operation_name} failed (attempt {attempt + 1}/{max_attempts}): {str(e)}. Retrying in {total_delay:.2f}s"
|
||||
)
|
||||
await asyncio.sleep(total_delay)
|
||||
|
||||
except (UnauthorizedException, ForbiddenException) as e:
|
||||
# non-retryable auth errors
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"[Pinecone] {operation_name} failed with auth error in {execution_time:.2f}s: {str(e)}")
|
||||
raise
|
||||
|
||||
except NotFoundException as e:
|
||||
# non-retryable not found errors
|
||||
execution_time = time.time() - start_time
|
||||
logger.warning(f"[Pinecone] {operation_name} failed with not found error in {execution_time:.2f}s: {str(e)}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
# other unexpected errors - retry once then fail
|
||||
if attempt == max_attempts - 1:
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"[Pinecone] {operation_name} failed after {max_attempts} attempts in {execution_time:.2f}s: {str(e)}")
|
||||
raise
|
||||
|
||||
delay = min(base_delay * (backoff_factor**attempt), max_delay)
|
||||
jitter = random.uniform(0, delay * 0.1)
|
||||
total_delay = delay + jitter
|
||||
|
||||
logger.warning(
|
||||
f"[Pinecone] {operation_name} failed with unexpected error (attempt {attempt + 1}/{max_attempts}): {str(e)}. Retrying in {total_delay:.2f}s"
|
||||
)
|
||||
await asyncio.sleep(total_delay)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def should_use_pinecone(verbose: bool = False):
|
||||
if verbose:
|
||||
logger.info(
|
||||
@@ -44,29 +141,42 @@ def should_use_pinecone(verbose: bool = False):
|
||||
)
|
||||
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def upsert_pinecone_indices():
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
for index_name in get_pinecone_indices():
|
||||
indices = get_pinecone_indices()
|
||||
logger.info(f"[Pinecone] Upserting {len(indices)} indices: {indices}")
|
||||
|
||||
for index_name in indices:
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
if not await pc.has_index(index_name):
|
||||
logger.info(f"[Pinecone] Creating index {index_name} with model {PINECONE_EMBEDDING_MODEL}")
|
||||
await pc.create_index_for_model(
|
||||
name=index_name,
|
||||
cloud=PINECONE_CLOUD,
|
||||
region=PINECONE_REGION,
|
||||
embed=IndexEmbed(model=PINECONE_EMBEDDING_MODEL, field_map={"text": PINECONE_TEXT_FIELD_NAME}, metric=PINECONE_METRIC),
|
||||
)
|
||||
logger.info(f"[Pinecone] Successfully created index {index_name}")
|
||||
else:
|
||||
logger.debug(f"[Pinecone] Index {index_name} already exists")
|
||||
|
||||
|
||||
def get_pinecone_indices() -> List[str]:
|
||||
return [settings.pinecone_agent_index, settings.pinecone_source_index]
|
||||
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, chunks: List[str], actor: User):
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
logger.info(f"[Pinecone] Preparing to upsert {len(chunks)} chunks for file {file_id} source {source_id}")
|
||||
|
||||
records = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
record = {
|
||||
@@ -77,14 +187,19 @@ async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, ch
|
||||
}
|
||||
records.append(record)
|
||||
|
||||
logger.debug(f"[Pinecone] Created {len(records)} records for file {file_id}")
|
||||
return await upsert_records_to_pinecone_index(records, actor)
|
||||
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
namespace = actor.organization_id
|
||||
logger.info(f"[Pinecone] Deleting records for file {file_id} from index {settings.pinecone_source_index} namespace {namespace}")
|
||||
|
||||
try:
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
@@ -95,48 +210,72 @@ async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
|
||||
},
|
||||
namespace=namespace,
|
||||
)
|
||||
logger.info(f"[Pinecone] Successfully deleted records for file {file_id}")
|
||||
except NotFoundException:
|
||||
logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
|
||||
logger.warning(f"[Pinecone] Namespace {namespace} not found for file {file_id} and org {actor.organization_id}")
|
||||
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def delete_source_records_from_pinecone_index(source_id: str, actor: User):
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
namespace = actor.organization_id
|
||||
logger.info(f"[Pinecone] Deleting records for source {source_id} from index {settings.pinecone_source_index} namespace {namespace}")
|
||||
|
||||
try:
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
await dense_index.delete(filter={"source_id": {"$eq": source_id}}, namespace=namespace)
|
||||
logger.info(f"[Pinecone] Successfully deleted records for source {source_id}")
|
||||
except NotFoundException:
|
||||
logger.warning(f"Pinecone namespace {namespace} not found for {source_id} and {actor.organization_id}")
|
||||
logger.warning(f"[Pinecone] Namespace {namespace} not found for source {source_id} and org {actor.organization_id}")
|
||||
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def upsert_records_to_pinecone_index(records: List[dict], actor: User):
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
logger.info(f"[Pinecone] Upserting {len(records)} records to index {settings.pinecone_source_index} for org {actor.organization_id}")
|
||||
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
# Process records in batches to avoid exceeding Pinecone limits
|
||||
# process records in batches to avoid exceeding pinecone limits
|
||||
total_batches = (len(records) + PINECONE_MAX_BATCH_SIZE - 1) // PINECONE_MAX_BATCH_SIZE
|
||||
logger.debug(f"[Pinecone] Processing {total_batches} batches of max {PINECONE_MAX_BATCH_SIZE} records each")
|
||||
|
||||
for i in range(0, len(records), PINECONE_MAX_BATCH_SIZE):
|
||||
batch = records[i : i + PINECONE_MAX_BATCH_SIZE]
|
||||
batch_num = (i // PINECONE_MAX_BATCH_SIZE) + 1
|
||||
|
||||
logger.debug(f"[Pinecone] Upserting batch {batch_num}/{total_batches} with {len(batch)} records")
|
||||
await dense_index.upsert_records(actor.organization_id, batch)
|
||||
|
||||
logger.info(f"[Pinecone] Successfully upserted all {len(records)} records in {total_batches} batches")
|
||||
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], actor: User) -> Dict[str, Any]:
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
namespace = actor.organization_id
|
||||
logger.info(
|
||||
f"[Pinecone] Searching index {settings.pinecone_source_index} namespace {namespace} with query length {len(query)} chars, limit {limit}"
|
||||
)
|
||||
logger.debug(f"[Pinecone] Search filter: {filter}")
|
||||
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
namespace = actor.organization_id
|
||||
|
||||
try:
|
||||
# Search the dense index with reranking
|
||||
# search the dense index with reranking
|
||||
search_results = await dense_index.search(
|
||||
namespace=namespace,
|
||||
query={
|
||||
@@ -146,17 +285,26 @@ async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any],
|
||||
},
|
||||
rerank={"model": "bge-reranker-v2-m3", "top_n": limit, "rank_fields": [PINECONE_TEXT_FIELD_NAME]},
|
||||
)
|
||||
|
||||
result_count = len(search_results.get("matches", []))
|
||||
logger.info(f"[Pinecone] Search completed, found {result_count} matches")
|
||||
return search_results
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search Pinecone namespace {actor.organization_id}: {str(e)}")
|
||||
logger.warning(f"[Pinecone] Failed to search namespace {namespace}: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
namespace = actor.organization_id
|
||||
logger.info(f"[Pinecone] Listing records for file {file_id} from index {settings.pinecone_source_index} namespace {namespace}")
|
||||
logger.debug(f"[Pinecone] List params - limit: {limit}, pagination_token: {pagination_token}")
|
||||
|
||||
try:
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
@@ -172,9 +320,14 @@ async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int =
|
||||
result = []
|
||||
async for ids in dense_index.list(**kwargs):
|
||||
result.extend(ids)
|
||||
|
||||
logger.info(f"[Pinecone] Successfully listed {len(result)} records for file {file_id}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list Pinecone namespace {actor.organization_id}: {str(e)}")
|
||||
logger.warning(f"[Pinecone] Failed to list records for file {file_id} in namespace {namespace}: {str(e)}")
|
||||
raise e
|
||||
|
||||
except NotFoundException:
|
||||
logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
|
||||
logger.warning(f"[Pinecone] Namespace {namespace} not found for file {file_id} and org {actor.organization_id}")
|
||||
return []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, Integer, String, Text, UniqueConstraint, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
@@ -11,10 +11,7 @@ from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.files_agents import FileAgent
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import SourcePassage
|
||||
from letta.orm.source import Source
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Note that this is NOT organization scoped, this is potentially dangerous if we misuse this
|
||||
@@ -64,18 +61,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
|
||||
chunks_embedded: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Number of chunks that have been embedded.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
||||
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
|
||||
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||
"SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
file_agents: Mapped[List["FileAgent"]] = relationship(
|
||||
"FileAgent",
|
||||
back_populates="file",
|
||||
lazy="selectin",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True, # ← add this
|
||||
)
|
||||
content: Mapped[Optional["FileContent"]] = relationship(
|
||||
"FileContent",
|
||||
uselist=False,
|
||||
|
||||
@@ -12,7 +12,7 @@ from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.file import FileAgent as PydanticFileAgent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.file import FileMetadata
|
||||
pass
|
||||
|
||||
|
||||
class FileAgent(SqlalchemyBase, OrganizationMixin):
|
||||
@@ -55,6 +55,12 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
|
||||
nullable=False,
|
||||
doc="ID of the agent",
|
||||
)
|
||||
source_id: Mapped[str] = mapped_column(
|
||||
String,
|
||||
ForeignKey("sources.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
doc="ID of the source (denormalized from files.source_id)",
|
||||
)
|
||||
|
||||
file_name: Mapped[str] = mapped_column(
|
||||
String,
|
||||
@@ -78,13 +84,6 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
|
||||
back_populates="file_agents",
|
||||
lazy="selectin",
|
||||
)
|
||||
file: Mapped["FileMetadata"] = relationship(
|
||||
"FileMetadata",
|
||||
foreign_keys=[file_id],
|
||||
lazy="selectin",
|
||||
back_populates="file_agents",
|
||||
passive_deletes=True, # ← add this
|
||||
)
|
||||
|
||||
# TODO: This is temporary as we figure out if we want FileBlock as a first class citizen
|
||||
def to_pydantic_block(self) -> PydanticBlock:
|
||||
@@ -99,8 +98,8 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
|
||||
return PydanticBlock(
|
||||
organization_id=self.organization_id,
|
||||
value=visible_content,
|
||||
label=self.file.file_name,
|
||||
label=self.file_name, # use denormalized file_name instead of self.file.file_name
|
||||
read_only=True,
|
||||
metadata={"source_id": self.file.source_id},
|
||||
metadata={"source_id": self.source_id}, # use denormalized source_id
|
||||
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.agent_passage import AgentPassage
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.llm_batch_item import LLMBatchItem
|
||||
@@ -18,7 +17,6 @@ if TYPE_CHECKING:
|
||||
from letta.orm.provider import Provider
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig
|
||||
from letta.orm.sandbox_environment_variable import SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.source_passage import SourcePassage
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
@@ -38,8 +36,6 @@ class Organization(SqlalchemyBase):
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
||||
# mcp_servers: Mapped[List["MCPServer"]] = relationship("MCPServer", back_populates="organization", cascade="all, delete-orphan")
|
||||
blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan")
|
||||
sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
|
||||
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
|
||||
sandbox_configs: Mapped[List["SandboxConfig"]] = relationship(
|
||||
"SandboxConfig", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@@ -49,11 +49,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
|
||||
file_name: Mapped[str] = mapped_column(doc="The name of the file that this passage was derived from")
|
||||
|
||||
@declared_attr
|
||||
def file(cls) -> Mapped["FileMetadata"]:
|
||||
"""Relationship to file"""
|
||||
return relationship("FileMetadata", back_populates="source_passages", lazy="selectin")
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
return relationship("Organization", back_populates="source_passages", lazy="selectin")
|
||||
@@ -74,11 +69,6 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
{"extend_existing": True},
|
||||
)
|
||||
|
||||
@declared_attr
|
||||
def source(cls) -> Mapped["Source"]:
|
||||
"""Relationship to source"""
|
||||
return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True)
|
||||
|
||||
|
||||
class AgentPassage(BasePassage, AgentMixin):
|
||||
"""Passages created by agents as archival memories"""
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import JSON, Index, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm import FileMetadata
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
@@ -11,10 +10,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import SourcePassage
|
||||
pass
|
||||
|
||||
|
||||
class Source(SqlalchemyBase, OrganizationMixin):
|
||||
@@ -34,16 +30,3 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
||||
instructions: Mapped[str] = mapped_column(nullable=True, doc="instructions for how to use the source")
|
||||
embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.")
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
|
||||
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
|
||||
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent",
|
||||
secondary="sources_agents",
|
||||
back_populates="sources",
|
||||
lazy="selectin",
|
||||
cascade="save-update", # Only propagate save and update operations
|
||||
passive_deletes=True, # Let the database handle deletions
|
||||
)
|
||||
|
||||
@@ -85,6 +85,7 @@ class FileAgent(FileAgentBase):
|
||||
)
|
||||
agent_id: str = Field(..., description="Unique identifier of the agent.")
|
||||
file_id: str = Field(..., description="Unique identifier of the file.")
|
||||
source_id: str = Field(..., description="Unique identifier of the source (denormalized from files.source_id).")
|
||||
file_name: str = Field(..., description="Name of the file.")
|
||||
is_open: bool = Field(True, description="True if the agent currently has the file open.")
|
||||
visible_content: Optional[str] = Field(
|
||||
|
||||
@@ -210,7 +210,7 @@ class BasicBlockMemory(Memory):
|
||||
Append to the contents of core memory.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited (persona or human).
|
||||
label (str): Section of the memory to be edited.
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
@@ -226,7 +226,7 @@ class BasicBlockMemory(Memory):
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited (persona or human).
|
||||
label (str): Section of the memory to be edited.
|
||||
old_content (str): String to replace. Must be an exact match.
|
||||
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
|
||||
@@ -272,14 +272,14 @@ async def modify_agent(
|
||||
|
||||
|
||||
@router.get("/{agent_id}/tools", response_model=list[Tool], operation_id="list_agent_tools")
|
||||
def list_agent_tools(
|
||||
async def list_agent_tools(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Get tools from an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.list_attached_tools_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/tools/attach/{tool_id}", response_model=AgentState, operation_id="attach_tool")
|
||||
@@ -1072,7 +1072,7 @@ async def _process_message_background(
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
await server.job_manager.update_job_by_id_async(job_id=job_id, job_update=job_update, actor=actor)
|
||||
await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query, status
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query
|
||||
from fastapi.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
@@ -45,12 +45,8 @@ async def create_messages_batch(
|
||||
if length > max_bytes:
|
||||
raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
|
||||
|
||||
# Reject request if env var is not set
|
||||
if not settings.enable_batch_job_polling:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.",
|
||||
)
|
||||
logger.warning("Batch job polling is disabled. Enable batch processing by setting LETTA_ENABLE_BATCH_JOB_POLLING to True.")
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
batch_job = BatchJob(
|
||||
|
||||
@@ -391,8 +391,8 @@ async def get_file_metadata(
|
||||
if file_metadata.source_id != source_id:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.")
|
||||
|
||||
if should_use_pinecone() and not file_metadata.is_processing_terminal():
|
||||
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor, limit=file_metadata.total_chunks)
|
||||
if should_use_pinecone() and file_metadata.processing_status == FileProcessingStatus.EMBEDDING:
|
||||
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor)
|
||||
logger.info(
|
||||
f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_id} ({file_metadata.file_name}) in organization {actor.organization_id}"
|
||||
)
|
||||
@@ -402,7 +402,7 @@ async def get_file_metadata(
|
||||
file_status = file_metadata.processing_status
|
||||
else:
|
||||
file_status = FileProcessingStatus.COMPLETED
|
||||
await server.file_manager.update_file_status(
|
||||
file_metadata = await server.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
|
||||
)
|
||||
|
||||
|
||||
@@ -1342,9 +1342,6 @@ class SyncServer(Server):
|
||||
new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
|
||||
assert new_passage_size >= curr_passage_size # in case empty files are added
|
||||
|
||||
# rebuild system prompt and force
|
||||
agent_state = await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
|
||||
|
||||
# update job status
|
||||
job.status = JobStatus.completed
|
||||
job.metadata["num_passages"] = num_passages
|
||||
|
||||
@@ -107,6 +107,32 @@ class AgentManager:
|
||||
self.identity_manager = IdentityManager()
|
||||
self.file_agent_manager = FileAgentManager()
|
||||
|
||||
@trace_method
|
||||
async def _validate_agent_exists_async(self, session, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Validate that an agent exists and user has access to it using raw SQL for efficiency.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
agent_id: ID of the agent to validate
|
||||
actor: User performing the action
|
||||
|
||||
Raises:
|
||||
NoResultFound: If agent doesn't exist or user doesn't have access
|
||||
"""
|
||||
agent_check_query = sa.text(
|
||||
"""
|
||||
SELECT 1 FROM agents
|
||||
WHERE id = :agent_id
|
||||
AND organization_id = :org_id
|
||||
AND is_deleted = false
|
||||
"""
|
||||
)
|
||||
agent_exists = await session.execute(agent_check_query, {"agent_id": agent_id, "org_id": actor.organization_id})
|
||||
|
||||
if not agent_exists.fetchone():
|
||||
raise NoResultFound(f"Agent with ID {agent_id} not found")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
|
||||
"""
|
||||
@@ -635,24 +661,24 @@ class AgentManager:
|
||||
|
||||
return init_messages
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def append_initial_message_sequence_to_in_context_messages(
|
||||
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
|
||||
) -> PydanticAgentState:
|
||||
init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
|
||||
return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def append_initial_message_sequence_to_in_context_messages_async(
|
||||
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
|
||||
) -> PydanticAgentState:
|
||||
init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
|
||||
return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -773,8 +799,8 @@ class AgentManager:
|
||||
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -1125,16 +1151,16 @@ class AgentManager:
|
||||
async with db_registry.async_session() as session:
|
||||
return await AgentModel.size_async(db_session=session, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
with db_registry.session() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_agent_by_id_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -1147,8 +1173,8 @@ class AgentManager:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
return await agent.to_pydantic_async(include_relationships=include_relationships)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_agents_by_ids_async(
|
||||
self,
|
||||
agent_ids: list[str],
|
||||
@@ -1164,16 +1190,16 @@ class AgentManager:
|
||||
)
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
with db_registry.session() as session:
|
||||
agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_agent(self, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Deletes an agent and its associated relationships.
|
||||
@@ -1220,8 +1246,8 @@ class AgentManager:
|
||||
else:
|
||||
logger.debug(f"Agent with ID {agent_id} successfully hard deleted")
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Deletes an agent and its associated relationships.
|
||||
@@ -1270,8 +1296,8 @@ class AgentManager:
|
||||
else:
|
||||
logger.debug(f"Agent with ID {agent_id} successfully hard deleted")
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema:
|
||||
with db_registry.session() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
@@ -1279,8 +1305,8 @@ class AgentManager:
|
||||
data = schema.dump(agent)
|
||||
return AgentSchema(**data)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def deserialize(
|
||||
self,
|
||||
serialized_agent: AgentSchema,
|
||||
@@ -1349,8 +1375,8 @@ class AgentManager:
|
||||
# ======================================================================================================================
|
||||
# Per Agent Environment Variable Management
|
||||
# ======================================================================================================================
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def _set_environment_variables(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -1405,8 +1431,8 @@ class AgentManager:
|
||||
# Return the updated agent state
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]:
|
||||
with db_registry.session() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
@@ -1422,20 +1448,20 @@ class AgentManager:
|
||||
# TODO: 2) These messages are ordered from oldest to newest
|
||||
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
|
||||
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
|
||||
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
|
||||
@@ -1443,8 +1469,8 @@ class AgentManager:
|
||||
# TODO: This is duplicated below
|
||||
# TODO: This is legacy code and should be cleaned up
|
||||
# TODO: A lot of the memory "compilation" should be offset to a separate class
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||||
|
||||
@@ -1515,29 +1541,42 @@ class AgentManager:
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
@trace_method
|
||||
# TODO: This is probably one of the worst pieces of code I've ever written please rip up as you see wish
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def rebuild_system_prompt_async(
|
||||
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True, tool_rules_solver: Optional[ToolRulesSolver] = None
|
||||
) -> PydanticAgentState:
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
force=False,
|
||||
update_timestamp=True,
|
||||
tool_rules_solver: Optional[ToolRulesSolver] = None,
|
||||
dry_run: bool = False,
|
||||
) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]:
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||||
|
||||
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
|
||||
|
||||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||||
"""
|
||||
# Get the current agent state
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources"], actor=actor)
|
||||
num_messages_task = self.message_manager.size_async(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories_task = self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
|
||||
agent_state_task = self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
|
||||
|
||||
num_messages, num_archival_memories, agent_state = await asyncio.gather(
|
||||
num_messages_task,
|
||||
num_archival_memories_task,
|
||||
agent_state_task,
|
||||
)
|
||||
|
||||
if not tool_rules_solver:
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
|
||||
curr_system_message = await self.get_system_message_async(
|
||||
agent_id=agent_id, actor=actor
|
||||
) # this is the system + memory bank, not just the system prompt
|
||||
curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||||
|
||||
if curr_system_message is None:
|
||||
logger.warning(f"No system message found for agent {agent_state.id} and user {actor}")
|
||||
return agent_state
|
||||
return agent_state, curr_system_message, num_messages, num_archival_memories
|
||||
|
||||
curr_system_message_openai = curr_system_message.to_openai_dict()
|
||||
|
||||
@@ -1551,7 +1590,7 @@ class AgentManager:
|
||||
logger.debug(
|
||||
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return agent_state
|
||||
return agent_state, curr_system_message, num_messages, num_archival_memories
|
||||
|
||||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||||
# For example, if we're doing a system prompt swap, this should probably be False
|
||||
@@ -1561,9 +1600,6 @@ class AgentManager:
|
||||
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
|
||||
memory_edit_timestamp = curr_system_message.created_at
|
||||
|
||||
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
|
||||
new_system_message_str = compile_system_message(
|
||||
@@ -1582,63 +1618,67 @@ class AgentManager:
|
||||
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
# Swap the system message out (only if there is a diff)
|
||||
message = PydanticMessage.dict_to_message(
|
||||
temp_message = PydanticMessage.dict_to_message(
|
||||
agent_id=agent_id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict={"role": "system", "content": new_system_message_str},
|
||||
)
|
||||
message = await self.message_manager.update_message_by_id_async(
|
||||
message_id=curr_system_message.id,
|
||||
message_update=MessageUpdate(**message.model_dump()),
|
||||
actor=actor,
|
||||
)
|
||||
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
|
||||
else:
|
||||
return agent_state
|
||||
temp_message.id = curr_system_message.id
|
||||
|
||||
if not dry_run:
|
||||
await self.message_manager.update_message_by_id_async(
|
||||
message_id=curr_system_message.id,
|
||||
message_update=MessageUpdate(**temp_message.model_dump()),
|
||||
actor=actor,
|
||||
)
|
||||
else:
|
||||
curr_system_message = temp_message
|
||||
|
||||
return agent_state, curr_system_message, num_messages, num_archival_memories
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
# TODO: How do we know this?
|
||||
new_messages = [message_ids[0]] # 0 is system message
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
new_messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||||
message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:]
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
|
||||
message_ids += [m.id for m in messages]
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def append_to_in_context_messages_async(
|
||||
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
|
||||
) -> PydanticAgentState:
|
||||
@@ -1648,8 +1688,8 @@ class AgentManager:
|
||||
message_ids += [m.id for m in messages]
|
||||
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def reset_messages_async(
|
||||
self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False
|
||||
) -> PydanticAgentState:
|
||||
@@ -1712,8 +1752,8 @@ class AgentManager:
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
@@ -1756,12 +1796,12 @@ class AgentManager:
|
||||
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
|
||||
# rebuild memory - this records the last edited timestamp of the memory
|
||||
# TODO: pass in update timestamp from block edit time
|
||||
agent_state = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor)
|
||||
await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
return agent_state
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||||
# TODO: This will NOT work for new blocks/file blocks added intra-step
|
||||
block_ids = [b.id for b in agent_state.memory.blocks]
|
||||
@@ -1779,8 +1819,8 @@ class AgentManager:
|
||||
|
||||
return agent_state
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def refresh_file_blocks(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||||
file_blocks = await self.file_agent_manager.list_files_for_agent(agent_id=agent_state.id, actor=actor, return_as_blocks=True)
|
||||
agent_state.memory.file_blocks = [b for b in file_blocks if b is not None]
|
||||
@@ -1789,8 +1829,8 @@ class AgentManager:
|
||||
# ======================================================================================================================
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a source to an agent.
|
||||
@@ -1820,15 +1860,11 @@ class AgentManager:
|
||||
)
|
||||
|
||||
# Commit the changes
|
||||
await agent.update_async(session, actor=actor)
|
||||
agent = await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# Force rebuild of system prompt so that the agent is updated with passage count
|
||||
pydantic_agent = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
|
||||
|
||||
return pydantic_agent
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def append_system_message(self, agent_id: str, content: str, actor: PydanticUser):
|
||||
|
||||
# get the agent
|
||||
@@ -1840,8 +1876,8 @@ class AgentManager:
|
||||
# update agent in-context message IDs
|
||||
self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser):
|
||||
|
||||
# get the agent
|
||||
@@ -1853,28 +1889,8 @@ class AgentManager:
|
||||
# update agent in-context message IDs
|
||||
await self.append_to_in_context_messages_async(messages=[message], agent_id=agent_id, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
|
||||
"""
|
||||
Lists all sources attached to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to list sources for
|
||||
actor: User performing the action
|
||||
|
||||
Returns:
|
||||
List[str]: List of source IDs attached to the agent
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
# Verify agent exists and user has permission to access it
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Use the lazy-loaded relationship to get sources
|
||||
return [source.to_pydantic() for source in agent.sources]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
|
||||
"""
|
||||
Lists all sources attached to an agent.
|
||||
@@ -1885,44 +1901,34 @@ class AgentManager:
|
||||
|
||||
Returns:
|
||||
List[str]: List of source IDs attached to the agent
|
||||
|
||||
Raises:
|
||||
NoResultFound: If agent doesn't exist or user doesn't have access
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify agent exists and user has permission to access it
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
# Validate agent exists and user has access
|
||||
await self._validate_agent_exists_async(session, agent_id, actor)
|
||||
|
||||
# Use the lazy-loaded relationship to get sources
|
||||
return [source.to_pydantic() for source in agent.sources]
|
||||
# Use raw SQL to efficiently fetch sources - much faster than lazy loading
|
||||
# Fast query without relationship loading
|
||||
query = (
|
||||
select(SourceModel)
|
||||
.join(SourcesAgents, SourceModel.id == SourcesAgents.source_id)
|
||||
.where(
|
||||
SourcesAgents.agent_id == agent_id,
|
||||
SourceModel.organization_id == actor.organization_id,
|
||||
SourceModel.is_deleted == False,
|
||||
)
|
||||
.order_by(SourceModel.created_at.desc(), SourceModel.id)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
sources = result.scalars().all()
|
||||
|
||||
return [source.to_pydantic() for source in sources]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a source from an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to detach the source from
|
||||
source_id: ID of the source to detach
|
||||
actor: User performing the action
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
# Verify agent exists and user has permission to access it
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Remove the source from the relationship
|
||||
remaining_sources = [s for s in agent.sources if s.id != source_id]
|
||||
|
||||
if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship
|
||||
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
|
||||
|
||||
# Update the sources relationship
|
||||
agent.sources = remaining_sources
|
||||
|
||||
# Commit the changes
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a source from an agent.
|
||||
@@ -1931,29 +1937,36 @@ class AgentManager:
|
||||
agent_id: ID of the agent to detach the source from
|
||||
source_id: ID of the source to detach
|
||||
actor: User performing the action
|
||||
|
||||
Raises:
|
||||
NoResultFound: If agent doesn't exist or user doesn't have access
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify agent exists and user has permission to access it
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
# Validate agent exists and user has access
|
||||
await self._validate_agent_exists_async(session, agent_id, actor)
|
||||
|
||||
# Remove the source from the relationship
|
||||
remaining_sources = [s for s in agent.sources if s.id != source_id]
|
||||
# Check if the source is actually attached to this agent using junction table
|
||||
attachment_check_query = select(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
|
||||
attachment_result = await session.execute(attachment_check_query)
|
||||
attachment = attachment_result.scalar_one_or_none()
|
||||
|
||||
if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship
|
||||
if not attachment:
|
||||
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
|
||||
else:
|
||||
# Delete the association directly from the junction table
|
||||
delete_query = delete(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
|
||||
await session.execute(delete_query)
|
||||
await session.commit()
|
||||
|
||||
# Update the sources relationship
|
||||
agent.sources = remaining_sources
|
||||
|
||||
# Commit the changes
|
||||
await agent.update_async(session, actor=actor)
|
||||
# Get agent without loading relationships for return value
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block management
|
||||
# ======================================================================================================================
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_block_with_label(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -1968,8 +1981,8 @@ class AgentManager:
|
||||
return block.to_pydantic()
|
||||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_block_with_label_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -1984,8 +1997,8 @@ class AgentManager:
|
||||
return block.to_pydantic()
|
||||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def modify_block_by_label_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -2012,8 +2025,8 @@ class AgentManager:
|
||||
await block.update_async(session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_block_with_label(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -2037,8 +2050,8 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group."""
|
||||
with db_registry.session() as session:
|
||||
@@ -2067,8 +2080,8 @@ class AgentManager:
|
||||
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -2103,8 +2116,8 @@ class AgentManager:
|
||||
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def detach_block(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -2124,8 +2137,8 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def detach_block_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -2145,8 +2158,8 @@ class AgentManager:
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def detach_block_with_label(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -2170,8 +2183,8 @@ class AgentManager:
|
||||
# Passage Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -2231,8 +2244,8 @@ class AgentManager:
|
||||
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_passages_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -2292,8 +2305,8 @@ class AgentManager:
|
||||
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_source_passages_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -2340,8 +2353,8 @@ class AgentManager:
|
||||
# Convert to Pydantic models
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_agent_passages_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -2384,8 +2397,8 @@ class AgentManager:
|
||||
# Convert to Pydantic models
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def passage_size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -2465,8 +2478,8 @@ class AgentManager:
|
||||
# ======================================================================================================================
|
||||
# Tool Management
|
||||
# ======================================================================================================================
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a tool to an agent.
|
||||
@@ -2501,8 +2514,8 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a tool to an agent.
|
||||
@@ -2537,8 +2550,8 @@ class AgentManager:
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def attach_missing_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches missing core file tools to an agent.
|
||||
@@ -2569,8 +2582,8 @@ class AgentManager:
|
||||
|
||||
return agent_state
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def detach_all_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detach all core file tools from an agent.
|
||||
@@ -2596,8 +2609,8 @@ class AgentManager:
|
||||
|
||||
return agent_state
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a tool from an agent.
|
||||
@@ -2630,8 +2643,8 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a tool from an agent.
|
||||
@@ -2664,8 +2677,8 @@ class AgentManager:
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""
|
||||
List all tools attached to an agent.
|
||||
@@ -2681,11 +2694,40 @@ class AgentManager:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
return [tool.to_pydantic() for tool in agent.tools]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_attached_tools_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""
|
||||
List all tools attached to an agent (async version with optimized performance).
|
||||
Uses direct SQL queries to avoid SqlAlchemyBase overhead.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to list tools for.
|
||||
actor: User performing the action.
|
||||
|
||||
Returns:
|
||||
List[PydanticTool]: List of tools attached to the agent.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# lightweight check for agent access
|
||||
await self._validate_agent_exists_async(session, agent_id, actor)
|
||||
|
||||
# direct query for tools via join - much more performant
|
||||
query = (
|
||||
select(ToolModel)
|
||||
.join(ToolsAgents, ToolModel.id == ToolsAgents.tool_id)
|
||||
.where(ToolsAgents.agent_id == agent_id, ToolModel.organization_id == actor.organization_id)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
tools = result.scalars().all()
|
||||
return [tool.to_pydantic() for tool in tools]
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tag Management
|
||||
# ======================================================================================================================
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_tags(
|
||||
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
|
||||
) -> List[str]:
|
||||
@@ -2719,8 +2761,8 @@ class AgentManager:
|
||||
results = [tag[0] for tag in query.all()]
|
||||
return results
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_tags_async(
|
||||
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
|
||||
) -> List[str]:
|
||||
@@ -2759,8 +2801,11 @@ class AgentManager:
|
||||
results = [row[0] for row in result.all()]
|
||||
return results
|
||||
|
||||
@trace_method
|
||||
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
|
||||
agent_state = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
|
||||
agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
|
||||
agent_id=agent_id, actor=actor, force=True, dry_run=True
|
||||
)
|
||||
calculator = ContextWindowCalculator()
|
||||
|
||||
if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION" and agent_state.llm_config.model_endpoint_type == "anthropic":
|
||||
@@ -2776,5 +2821,7 @@ class AgentManager:
|
||||
actor=actor,
|
||||
token_counter=token_counter,
|
||||
message_manager=self.message_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
system_message_compiled=system_message,
|
||||
num_archival_memories=num_archival_memories,
|
||||
num_messages=num_messages,
|
||||
)
|
||||
|
||||
@@ -23,8 +23,8 @@ logger = get_logger(__name__)
|
||||
class BlockManager:
|
||||
"""Manager class to handle business logic related to Blocks."""
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Create a new block based on the Block schema."""
|
||||
db_block = self.get_block_by_id(block.id, actor)
|
||||
@@ -38,8 +38,8 @@ class BlockManager:
|
||||
block.create(session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Create a new block based on the Block schema."""
|
||||
db_block = await self.get_block_by_id_async(block.id, actor)
|
||||
@@ -53,8 +53,8 @@ class BlockManager:
|
||||
await block.create_async(session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def batch_create_blocks(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
|
||||
"""
|
||||
Batch-create multiple Blocks in one transaction for better performance.
|
||||
@@ -77,8 +77,8 @@ class BlockManager:
|
||||
# Convert back to Pydantic
|
||||
return [m.to_pydantic() for m in created_models]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def batch_create_blocks_async(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
|
||||
"""
|
||||
Batch-create multiple Blocks in one transaction for better performance.
|
||||
@@ -101,8 +101,8 @@ class BlockManager:
|
||||
# Convert back to Pydantic
|
||||
return [m.to_pydantic() for m in created_models]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Update a block by its ID with the given BlockUpdate object."""
|
||||
# Safety check for block
|
||||
@@ -117,8 +117,8 @@ class BlockManager:
|
||||
block.update(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Update a block by its ID with the given BlockUpdate object."""
|
||||
# Safety check for block
|
||||
@@ -133,8 +133,8 @@ class BlockManager:
|
||||
await block.update_async(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Delete a block by its ID."""
|
||||
with db_registry.session() as session:
|
||||
@@ -142,8 +142,8 @@ class BlockManager:
|
||||
block.hard_delete(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_block_async(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Delete a block by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -151,8 +151,8 @@ class BlockManager:
|
||||
await block.hard_delete_async(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_blocks_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -214,8 +214,8 @@ class BlockManager:
|
||||
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its name."""
|
||||
with db_registry.session() as session:
|
||||
@@ -225,8 +225,8 @@ class BlockManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its name."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -236,8 +236,8 @@ class BlockManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
|
||||
from sqlalchemy import select
|
||||
@@ -284,8 +284,8 @@ class BlockManager:
|
||||
|
||||
return pydantic_blocks
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_agents_for_block_async(
|
||||
self,
|
||||
block_id: str,
|
||||
@@ -301,8 +301,8 @@ class BlockManager:
|
||||
agents = await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents_orm])
|
||||
return agents
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def size_async(self, actor: PydanticUser) -> int:
|
||||
"""
|
||||
Get the total count of blocks for the given user.
|
||||
@@ -312,8 +312,8 @@ class BlockManager:
|
||||
|
||||
# Block History Functions
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def checkpoint_block(
|
||||
self,
|
||||
block_id: str,
|
||||
@@ -416,8 +416,8 @@ class BlockManager:
|
||||
updated_block = block.update(db_session=session, actor=actor, no_commit=True)
|
||||
return updated_block
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
|
||||
"""
|
||||
Move the block to the immediately previous checkpoint in BlockHistory.
|
||||
@@ -459,8 +459,8 @@ class BlockManager:
|
||||
session.commit()
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
|
||||
"""
|
||||
Move the block to the next checkpoint if it exists.
|
||||
@@ -498,8 +498,8 @@ class BlockManager:
|
||||
session.commit()
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def bulk_update_block_values_async(
|
||||
self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False
|
||||
) -> Optional[List[PydanticBlock]]:
|
||||
|
||||
@@ -4,11 +4,14 @@ from typing import Any, List, Optional, Tuple
|
||||
from openai.types.beta.function_tool import FunctionTool as OpenAITool
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import ContextWindowOverview
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.context_window_calculator.token_counter import TokenCounter
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -56,16 +59,18 @@ class ContextWindowCalculator:
|
||||
return None, 1
|
||||
|
||||
async def calculate_context_window(
|
||||
self, agent_state: Any, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
actor: PydanticUser,
|
||||
token_counter: TokenCounter,
|
||||
message_manager: MessageManager,
|
||||
system_message_compiled: Message,
|
||||
num_archival_memories: int,
|
||||
num_messages: int,
|
||||
) -> ContextWindowOverview:
|
||||
"""Calculate context window information using the provided token counter"""
|
||||
|
||||
# Fetch data concurrently
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor),
|
||||
passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id),
|
||||
message_manager.size_async(actor=actor, agent_id=agent_state.id),
|
||||
)
|
||||
messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids[1:], actor=actor)
|
||||
in_context_messages = [system_message_compiled] + messages
|
||||
|
||||
# Convert messages to appropriate format
|
||||
converted_messages = token_counter.convert_messages(in_context_messages)
|
||||
@@ -128,8 +133,8 @@ class ContextWindowCalculator:
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(in_context_messages),
|
||||
num_archival_memory=passage_manager_size,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_archival_memory=num_archival_memories,
|
||||
num_recall_memory=num_messages,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
external_memory_summary=external_memory_summary,
|
||||
# top-level information
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import hashlib
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from letta.helpers.decorators import async_redis_cache
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.utils import count_tokens
|
||||
|
||||
@@ -33,16 +37,34 @@ class AnthropicTokenCounter(TokenCounter):
|
||||
self.client = anthropic_client
|
||||
self.model = model
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, text: f"anthropic_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_text_tokens(self, text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "content": text}])
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
|
||||
if not messages:
|
||||
return 0
|
||||
return await self.client.count_tokens(model=self.model, messages=messages)
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
|
||||
if not tools:
|
||||
return 0
|
||||
@@ -58,11 +80,23 @@ class TiktokenCounter(TokenCounter):
|
||||
def __init__(self, model: str):
|
||||
self.model = model
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_text_tokens(self, text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return count_tokens(text)
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
|
||||
if not messages:
|
||||
return 0
|
||||
@@ -70,6 +104,12 @@ class TiktokenCounter(TokenCounter):
|
||||
|
||||
return num_tokens_from_messages(messages=messages, model=self.model)
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
|
||||
if not tools:
|
||||
return 0
|
||||
|
||||
@@ -40,6 +40,10 @@ class LineChunker:
|
||||
|
||||
def _chunk_by_lines(self, text: str, preserve_indentation: bool = False) -> List[str]:
|
||||
"""Traditional line-based chunking for code and structured data"""
|
||||
# early stop, can happen if the there's nothing on a specific file
|
||||
if not text:
|
||||
return []
|
||||
|
||||
lines = []
|
||||
for line in text.splitlines():
|
||||
if preserve_indentation:
|
||||
@@ -57,6 +61,10 @@ class LineChunker:
|
||||
|
||||
def _chunk_by_sentences(self, text: str) -> List[str]:
|
||||
"""Sentence-based chunking for documentation and markup"""
|
||||
# early stop, can happen if the there's nothing on a specific file
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# Simple sentence splitting on periods, exclamation marks, and question marks
|
||||
# followed by whitespace or end of string
|
||||
sentence_pattern = r"(?<=[.!?])\s+(?=[A-Z])"
|
||||
@@ -75,6 +83,10 @@ class LineChunker:
|
||||
|
||||
def _chunk_by_characters(self, text: str, target_line_length: int = 100) -> List[str]:
|
||||
"""Character-based wrapping for prose text"""
|
||||
# early stop, can happen if the there's nothing on a specific file
|
||||
if not text:
|
||||
return []
|
||||
|
||||
words = text.split()
|
||||
lines = []
|
||||
current_line = []
|
||||
@@ -110,6 +122,11 @@ class LineChunker:
|
||||
strategy = self._determine_chunking_strategy(file_metadata)
|
||||
text = file_metadata.content
|
||||
|
||||
# early stop, can happen if the there's nothing on a specific file
|
||||
if not text:
|
||||
logger.warning(f"File ({file_metadata}) has no content")
|
||||
return []
|
||||
|
||||
# Apply the appropriate chunking strategy
|
||||
if strategy == ChunkingStrategy.DOCUMENTATION:
|
||||
content_lines = self._chunk_by_sentences(text)
|
||||
|
||||
@@ -25,7 +25,6 @@ class OpenAIEmbedder(BaseEmbedder):
|
||||
else EmbeddingConfig.default_config(model_name="letta")
|
||||
)
|
||||
self.embedding_config = embedding_config or self.default_embedding_config
|
||||
self.max_concurrent_requests = 20
|
||||
|
||||
# TODO: Unify to global OpenAI client
|
||||
self.client: OpenAIClient = cast(
|
||||
@@ -48,9 +47,55 @@ class OpenAIEmbedder(BaseEmbedder):
|
||||
"embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
|
||||
},
|
||||
)
|
||||
embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
|
||||
log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
|
||||
return [(idx, e) for idx, e in zip(batch_indices, embeddings)]
|
||||
|
||||
try:
|
||||
embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
|
||||
log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
|
||||
return [(idx, e) for idx, e in zip(batch_indices, embeddings)]
|
||||
except Exception as e:
|
||||
# if it's a token limit error and we can split, do it
|
||||
if self._is_token_limit_error(e) and len(batch) > 1:
|
||||
logger.warning(f"Token limit exceeded for batch of size {len(batch)}, splitting in half and retrying")
|
||||
log_event(
|
||||
"embedder.batch_split_retry",
|
||||
{
|
||||
"original_batch_size": len(batch),
|
||||
"error": str(e),
|
||||
"split_size": len(batch) // 2,
|
||||
},
|
||||
)
|
||||
|
||||
# split batch in half
|
||||
mid = len(batch) // 2
|
||||
batch1 = batch[:mid]
|
||||
batch1_indices = batch_indices[:mid]
|
||||
batch2 = batch[mid:]
|
||||
batch2_indices = batch_indices[mid:]
|
||||
|
||||
# retry with smaller batches
|
||||
result1 = await self._embed_batch(batch1, batch1_indices)
|
||||
result2 = await self._embed_batch(batch2, batch2_indices)
|
||||
|
||||
return result1 + result2
|
||||
else:
|
||||
# re-raise for other errors or if batch size is already 1
|
||||
raise
|
||||
|
||||
def _is_token_limit_error(self, error: Exception) -> bool:
|
||||
"""Check if the error is due to token limit exceeded"""
|
||||
# convert to string and check for token limit patterns
|
||||
error_str = str(error).lower()
|
||||
|
||||
# TODO: This is quite brittle, works for now
|
||||
# check for the specific patterns we see in token limit errors
|
||||
is_token_limit = (
|
||||
"max_tokens_per_request" in error_str
|
||||
or ("requested" in error_str and "tokens" in error_str and "max" in error_str and "per request" in error_str)
|
||||
or "token limit" in error_str
|
||||
or ("bad request to openai" in error_str and "tokens" in error_str and "max" in error_str)
|
||||
)
|
||||
|
||||
return is_token_limit
|
||||
|
||||
@trace_method
|
||||
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
|
||||
@@ -100,7 +145,7 @@ class OpenAIEmbedder(BaseEmbedder):
|
||||
|
||||
log_event(
|
||||
"embedder.concurrent_processing_started",
|
||||
{"concurrent_tasks": len(tasks), "max_concurrent_requests": self.max_concurrent_requests},
|
||||
{"concurrent_tasks": len(tasks)},
|
||||
)
|
||||
results = await asyncio.gather(*tasks)
|
||||
log_event("embedder.concurrent_processing_completed", {"batches_processed": len(results)})
|
||||
|
||||
@@ -29,6 +29,7 @@ class FileAgentManager:
|
||||
agent_id: str,
|
||||
file_id: str,
|
||||
file_name: str,
|
||||
source_id: str,
|
||||
actor: PydanticUser,
|
||||
is_open: bool = True,
|
||||
visible_content: Optional[str] = None,
|
||||
@@ -47,7 +48,12 @@ class FileAgentManager:
|
||||
if is_open:
|
||||
# Use the efficient LRU + open method
|
||||
closed_files, was_already_open = await self.enforce_max_open_files_and_open(
|
||||
agent_id=agent_id, file_id=file_id, file_name=file_name, actor=actor, visible_content=visible_content or ""
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
file_name=file_name,
|
||||
source_id=source_id,
|
||||
actor=actor,
|
||||
visible_content=visible_content or "",
|
||||
)
|
||||
|
||||
# Get the updated file agent to return
|
||||
@@ -85,6 +91,7 @@ class FileAgentManager:
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
file_name=file_name,
|
||||
source_id=source_id,
|
||||
organization_id=actor.organization_id,
|
||||
is_open=is_open,
|
||||
visible_content=visible_content,
|
||||
@@ -327,7 +334,7 @@ class FileAgentManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def enforce_max_open_files_and_open(
|
||||
self, *, agent_id: str, file_id: str, file_name: str, actor: PydanticUser, visible_content: str
|
||||
self, *, agent_id: str, file_id: str, file_name: str, source_id: str, actor: PydanticUser, visible_content: str
|
||||
) -> tuple[List[str], bool]:
|
||||
"""
|
||||
Efficiently handle LRU eviction and file opening in a single transaction.
|
||||
@@ -336,6 +343,7 @@ class FileAgentManager:
|
||||
agent_id: ID of the agent
|
||||
file_id: ID of the file to open
|
||||
file_name: Name of the file to open
|
||||
source_id: ID of the source (denormalized from files.source_id)
|
||||
actor: User performing the action
|
||||
visible_content: Content to set for the opened file
|
||||
|
||||
@@ -418,6 +426,7 @@ class FileAgentManager:
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
file_name=file_name,
|
||||
source_id=source_id,
|
||||
organization_id=actor.organization_id,
|
||||
is_open=True,
|
||||
visible_content=visible_content,
|
||||
@@ -516,6 +525,7 @@ class FileAgentManager:
|
||||
agent_id=agent_id,
|
||||
file_id=meta.id,
|
||||
file_name=meta.file_name,
|
||||
source_id=meta.source_id,
|
||||
organization_id=actor.organization_id,
|
||||
is_open=is_now_open,
|
||||
visible_content=vc,
|
||||
|
||||
@@ -19,8 +19,8 @@ from letta.utils import enforce_types
|
||||
|
||||
class GroupManager:
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_groups(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -45,22 +45,22 @@ class GroupManager:
|
||||
)
|
||||
return [group.to_pydantic() for group in groups]
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
|
||||
with db_registry.session() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
return group.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def retrieve_group_async(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
return group.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup:
|
||||
with db_registry.session() as session:
|
||||
new_group = GroupModel()
|
||||
@@ -150,8 +150,8 @@ class GroupManager:
|
||||
await new_group.create_async(session, actor=actor)
|
||||
return new_group.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
@@ -213,16 +213,16 @@ class GroupManager:
|
||||
await group.update_async(session, actor=actor)
|
||||
return group.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
|
||||
with db_registry.session() as session:
|
||||
# Retrieve the agent
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
group.hard_delete(session)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_group_messages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -258,8 +258,8 @@ class GroupManager:
|
||||
|
||||
return messages
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
|
||||
with db_registry.session() as session:
|
||||
# Ensure group is loadable by user
|
||||
@@ -272,8 +272,8 @@ class GroupManager:
|
||||
|
||||
session.commit()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
|
||||
with db_registry.session() as session:
|
||||
# Ensure group is loadable by user
|
||||
@@ -284,8 +284,8 @@ class GroupManager:
|
||||
group.update(session, actor=actor)
|
||||
return group.turns_counter
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int:
|
||||
async with db_registry.async_session() as session:
|
||||
# Ensure group is loadable by user
|
||||
@@ -309,8 +309,8 @@ class GroupManager:
|
||||
|
||||
return prev_last_processed_message_id
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_last_processed_message_id_and_update_async(
|
||||
self, group_id: str, last_processed_message_id: str, actor: PydanticUser
|
||||
) -> str:
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.source import Source as SourceModel
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
@@ -104,9 +108,21 @@ class SourceManager:
|
||||
# Verify source exists and user has permission to access it
|
||||
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
|
||||
|
||||
# The agents relationship is already loaded due to lazy="selectin" in the Source model
|
||||
# and will be properly filtered by organization_id due to the OrganizationMixin
|
||||
agents_orm = source.agents
|
||||
# Use junction table query instead of relationship to avoid performance issues
|
||||
query = (
|
||||
select(AgentModel)
|
||||
.join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
|
||||
.where(
|
||||
SourcesAgents.source_id == source_id,
|
||||
AgentModel.organization_id == actor.organization_id if actor else True,
|
||||
AgentModel.is_deleted == False,
|
||||
)
|
||||
.order_by(AgentModel.created_at.desc(), AgentModel.id)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
agents_orm = result.scalars().all()
|
||||
|
||||
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
|
||||
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
|
||||
@@ -188,7 +188,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
Append to the contents of core memory.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited (persona or human).
|
||||
label (str): Section of the memory to be edited.
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
@@ -214,7 +214,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited (persona or human).
|
||||
label (str): Section of the memory to be edited.
|
||||
old_content (str): String to replace. Must be an exact match.
|
||||
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
|
||||
@@ -180,7 +180,12 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
|
||||
# Handle LRU eviction and file opening
|
||||
closed_files, was_already_open = await self.files_agents_manager.enforce_max_open_files_and_open(
|
||||
agent_id=agent_state.id, file_id=file_id, file_name=file_name, actor=self.actor, visible_content=visible_content
|
||||
agent_id=agent_state.id,
|
||||
file_id=file_id,
|
||||
file_name=file_name,
|
||||
source_id=file.source_id,
|
||||
actor=self.actor,
|
||||
visible_content=visible_content,
|
||||
)
|
||||
|
||||
opened_files.append(file_name)
|
||||
|
||||
6
poetry.lock
generated
6
poetry.lock
generated
@@ -3591,14 +3591,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "letta-client"
|
||||
version = "0.1.198"
|
||||
version = "0.1.197"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "letta_client-0.1.198-py3-none-any.whl", hash = "sha256:08bbc238b128da2552b2a6e54feb3294794b5586e0962ce0bb95bb525109f58f"},
|
||||
{file = "letta_client-0.1.198.tar.gz", hash = "sha256:990c9132423e2955d9c7f7549e5064b2366616232d270e5927788cddba4ef9da"},
|
||||
{file = "letta_client-0.1.197-py3-none-any.whl", hash = "sha256:b01eab01ff87a34e79622cd8d3a2f3da56b6bd730312a268d968576671f6cce0"},
|
||||
{file = "letta_client-0.1.197.tar.gz", hash = "sha256:579571623ccec81422087cb7957d5f580fa0b7a53f8495596e5c56d0213220d8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.8.13"
|
||||
version = "0.8.14"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
@@ -116,7 +116,6 @@ google = ["google-genai"]
|
||||
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone"]
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.4.2"
|
||||
ipykernel = "^6.29.5"
|
||||
|
||||
@@ -306,7 +306,7 @@
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Section of the memory to be edited (persona or human)."
|
||||
"description": "Section of the memory to be edited."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
@@ -343,7 +343,7 @@
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Section of the memory to be edited (persona or human)."
|
||||
"description": "Section of the memory to be edited."
|
||||
},
|
||||
"old_content": {
|
||||
"type": "string",
|
||||
|
||||
@@ -237,7 +237,7 @@ def validate_context_window_overview(
|
||||
|
||||
# 16. Check attached file is visible
|
||||
if attached_file:
|
||||
assert attached_file.visible_content in overview.core_memory
|
||||
assert attached_file.visible_content in overview.core_memory, "File must be attached in core memory"
|
||||
assert '<file status="open"' in overview.core_memory
|
||||
assert "</file>" in overview.core_memory
|
||||
|
||||
|
||||
219
tests/test_file_processor.py
Normal file
219
tests/test_file_processor.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from letta.errors import ErrorCode, LLMBadRequestError
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||||
|
||||
|
||||
class TestOpenAIEmbedder:
|
||||
"""Test suite for OpenAI embedder functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock user for testing"""
|
||||
user = Mock()
|
||||
user.organization_id = "test_org_id"
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_config(self):
|
||||
"""Create a test embedding config"""
|
||||
return EmbeddingConfig(
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=3, # small dimension for testing
|
||||
embedding_chunk_size=300,
|
||||
batch_size=2, # small batch size for testing
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def embedder(self, embedding_config):
|
||||
"""Create OpenAI embedder with test config"""
|
||||
with patch("letta.services.file_processor.embedder.openai_embedder.LLMClient.create") as mock_create:
|
||||
mock_client = Mock()
|
||||
mock_client.handle_llm_error = Mock()
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
embedder = OpenAIEmbedder(embedding_config)
|
||||
embedder.client = mock_client
|
||||
return embedder
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_embedding_generation(self, embedder, mock_user):
|
||||
"""Test successful embedding generation for normal cases"""
|
||||
# mock successful embedding response
|
||||
mock_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
embedder.client.request_embeddings = AsyncMock(return_value=mock_embeddings)
|
||||
|
||||
chunks = ["chunk 1", "chunk 2"]
|
||||
file_id = "test_file"
|
||||
source_id = "test_source"
|
||||
|
||||
passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
|
||||
|
||||
assert len(passages) == 2
|
||||
assert passages[0].text == "chunk 1"
|
||||
assert passages[1].text == "chunk 2"
|
||||
# embeddings are padded to MAX_EMBEDDING_DIM, so check first 3 values
|
||||
assert passages[0].embedding[:3] == [0.1, 0.2, 0.3]
|
||||
assert passages[1].embedding[:3] == [0.4, 0.5, 0.6]
|
||||
assert passages[0].file_id == file_id
|
||||
assert passages[0].source_id == source_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_limit_retry_splits_batch(self, embedder, mock_user):
|
||||
"""Test that token limit errors trigger batch splitting and retry"""
|
||||
# create a mock token limit error
|
||||
mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
|
||||
token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
|
||||
|
||||
# first call fails with token limit, subsequent calls succeed
|
||||
call_count = 0
|
||||
|
||||
async def mock_request_embeddings(inputs, embedding_config):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1 and len(inputs) == 4: # first call with full batch
|
||||
raise token_limit_error
|
||||
elif len(inputs) == 2: # split batches succeed
|
||||
return [[0.1, 0.2], [0.3, 0.4]] if call_count == 2 else [[0.5, 0.6], [0.7, 0.8]]
|
||||
else:
|
||||
return [[0.1, 0.2]] * len(inputs)
|
||||
|
||||
embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings)
|
||||
|
||||
chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"]
|
||||
file_id = "test_file"
|
||||
source_id = "test_source"
|
||||
|
||||
passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
|
||||
|
||||
# should still get all 4 passages despite the retry
|
||||
assert len(passages) == 4
|
||||
assert all(len(p.embedding) == 4096 for p in passages) # padded to MAX_EMBEDDING_DIM
|
||||
# verify multiple calls were made (original + retries)
|
||||
assert call_count >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_limit_error_detection(self, embedder):
|
||||
"""Test various token limit error detection patterns"""
|
||||
# test openai BadRequestError with proper structure
|
||||
mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
|
||||
openai_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
|
||||
assert embedder._is_token_limit_error(openai_error) is True
|
||||
|
||||
# test error with message but no code
|
||||
mock_error_body_no_code = {"error": {"message": "max_tokens_per_request exceeded"}}
|
||||
openai_error_no_code = openai.BadRequestError(
|
||||
message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body_no_code
|
||||
)
|
||||
assert embedder._is_token_limit_error(openai_error_no_code) is True
|
||||
|
||||
# test fallback string detection
|
||||
generic_error = Exception("Requested 100000 tokens, max 50000 tokens per request")
|
||||
assert embedder._is_token_limit_error(generic_error) is True
|
||||
|
||||
# test non-token errors
|
||||
other_error = Exception("Some other error")
|
||||
assert embedder._is_token_limit_error(other_error) is False
|
||||
|
||||
auth_error = openai.AuthenticationError(
|
||||
message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}}
|
||||
)
|
||||
assert embedder._is_token_limit_error(auth_error) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_token_error_handling(self, embedder, mock_user):
|
||||
"""Test that non-token errors are properly handled and re-raised"""
|
||||
# create a non-token error
|
||||
auth_error = openai.AuthenticationError(
|
||||
message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}}
|
||||
)
|
||||
|
||||
# mock handle_llm_error to return a standardized error
|
||||
handled_error = LLMBadRequestError(message="Handled error", code=ErrorCode.UNAUTHENTICATED)
|
||||
embedder.client.handle_llm_error.return_value = handled_error
|
||||
embedder.client.request_embeddings = AsyncMock(side_effect=auth_error)
|
||||
|
||||
chunks = ["chunk 1"]
|
||||
file_id = "test_file"
|
||||
source_id = "test_source"
|
||||
|
||||
with pytest.raises(LLMBadRequestError) as exc_info:
|
||||
await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
|
||||
|
||||
assert exc_info.value == handled_error
|
||||
embedder.client.handle_llm_error.assert_called_once_with(auth_error)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_item_batch_no_retry(self, embedder, mock_user):
|
||||
"""Test that single-item batches don't retry on token limit errors"""
|
||||
# create a token limit error
|
||||
mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
|
||||
token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
|
||||
|
||||
handled_error = LLMBadRequestError(message="Handled token limit error", code=ErrorCode.INVALID_ARGUMENT)
|
||||
embedder.client.handle_llm_error.return_value = handled_error
|
||||
embedder.client.request_embeddings = AsyncMock(side_effect=token_limit_error)
|
||||
|
||||
chunks = ["very long chunk that exceeds token limit"]
|
||||
file_id = "test_file"
|
||||
source_id = "test_source"
|
||||
|
||||
with pytest.raises(LLMBadRequestError) as exc_info:
|
||||
await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
|
||||
|
||||
assert exc_info.value == handled_error
|
||||
embedder.client.handle_llm_error.assert_called_once_with(token_limit_error)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_chunks_handling(self, embedder, mock_user):
|
||||
"""Test handling of empty chunks list"""
|
||||
chunks = []
|
||||
file_id = "test_file"
|
||||
source_id = "test_source"
|
||||
|
||||
passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
|
||||
|
||||
assert passages == []
|
||||
# should not call request_embeddings for empty input
|
||||
embedder.client.request_embeddings.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_order_preservation(self, embedder, mock_user):
|
||||
"""Test that embedding order is preserved even with retries"""
|
||||
# set up embedder to split batches (batch_size=2)
|
||||
embedder.embedding_config.batch_size = 2
|
||||
|
||||
# mock responses for each batch
|
||||
async def mock_request_embeddings(inputs, embedding_config):
|
||||
# return embeddings that correspond to input order
|
||||
if inputs == ["chunk 1", "chunk 2"]:
|
||||
return [[0.1, 0.1], [0.2, 0.2]]
|
||||
elif inputs == ["chunk 3", "chunk 4"]:
|
||||
return [[0.3, 0.3], [0.4, 0.4]]
|
||||
else:
|
||||
return [[0.1, 0.1]] * len(inputs)
|
||||
|
||||
embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings)
|
||||
|
||||
chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"]
|
||||
file_id = "test_file"
|
||||
source_id = "test_source"
|
||||
|
||||
passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)
|
||||
|
||||
# verify order is preserved
|
||||
assert len(passages) == 4
|
||||
assert passages[0].text == "chunk 1"
|
||||
assert passages[0].embedding[:2] == [0.1, 0.1] # check first 2 values before padding
|
||||
assert passages[1].text == "chunk 2"
|
||||
assert passages[1].embedding[:2] == [0.2, 0.2]
|
||||
assert passages[2].text == "chunk 3"
|
||||
assert passages[2].embedding[:2] == [0.3, 0.3]
|
||||
assert passages[3].text == "chunk 4"
|
||||
assert passages[3].embedding[:2] == [0.4, 0.4]
|
||||
@@ -673,6 +673,7 @@ async def file_attachment(server, default_user, sarah_agent, default_file):
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
source_id=default_file.source_id,
|
||||
actor=default_user,
|
||||
visible_content="initial",
|
||||
)
|
||||
@@ -903,6 +904,7 @@ async def test_get_context_window_basic(
|
||||
agent_id=created_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
source_id=default_file.source_id,
|
||||
actor=default_user,
|
||||
visible_content="hello",
|
||||
)
|
||||
@@ -7221,6 +7223,7 @@ async def test_attach_creates_association(server, default_user, sarah_agent, def
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
source_id=default_file.source_id,
|
||||
actor=default_user,
|
||||
visible_content="hello",
|
||||
)
|
||||
@@ -7243,6 +7246,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
source_id=default_file.source_id,
|
||||
actor=default_user,
|
||||
visible_content="first",
|
||||
)
|
||||
@@ -7252,6 +7256,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
source_id=default_file.source_id,
|
||||
actor=default_user,
|
||||
is_open=False,
|
||||
visible_content="second",
|
||||
@@ -7326,15 +7331,28 @@ async def test_list_files_and_agents(
|
||||
):
|
||||
# default_file ↔ charles (open)
|
||||
await server.file_agent_manager.attach_file(
|
||||
agent_id=charles_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user
|
||||
agent_id=charles_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
source_id=default_file.source_id,
|
||||
actor=default_user,
|
||||
)
|
||||
# default_file ↔ sarah (open)
|
||||
await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id, file_id=default_file.id, file_name=default_file.file_name, actor=default_user
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
source_id=default_file.source_id,
|
||||
actor=default_user,
|
||||
)
|
||||
# another_file ↔ sarah (closed)
|
||||
await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id, file_id=another_file.id, file_name=another_file.file_name, actor=default_user, is_open=False
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=another_file.id,
|
||||
file_name=another_file.file_name,
|
||||
source_id=another_file.source_id,
|
||||
actor=default_user,
|
||||
is_open=False,
|
||||
)
|
||||
|
||||
files_for_sarah = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user)
|
||||
@@ -7384,6 +7402,7 @@ async def test_org_scoping(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
source_id=default_file.source_id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -7420,6 +7439,7 @@ async def test_mark_access_bulk(server, default_user, sarah_agent, default_sourc
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
source_id=file.source_id,
|
||||
actor=default_user,
|
||||
visible_content=f"content for {file.file_name}",
|
||||
)
|
||||
@@ -7478,6 +7498,7 @@ async def test_lru_eviction_on_attach(server, default_user, sarah_agent, default
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
source_id=file.source_id,
|
||||
actor=default_user,
|
||||
visible_content=f"content for {file.file_name}",
|
||||
)
|
||||
@@ -7530,6 +7551,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=files[i].id,
|
||||
file_name=files[i].file_name,
|
||||
source_id=files[i].source_id,
|
||||
actor=default_user,
|
||||
visible_content=f"content for {files[i].file_name}",
|
||||
)
|
||||
@@ -7539,6 +7561,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=files[-1].id,
|
||||
file_name=files[-1].file_name,
|
||||
source_id=files[-1].source_id,
|
||||
actor=default_user,
|
||||
is_open=False,
|
||||
visible_content=f"content for {files[-1].file_name}",
|
||||
@@ -7555,7 +7578,12 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
|
||||
|
||||
# Now "open" the last file using the efficient method
|
||||
closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content"
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=files[-1].id,
|
||||
file_name=files[-1].file_name,
|
||||
source_id=files[-1].source_id,
|
||||
actor=default_user,
|
||||
visible_content="updated content",
|
||||
)
|
||||
|
||||
# Should have closed 1 file (the oldest one)
|
||||
@@ -7603,6 +7631,7 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
source_id=file.source_id,
|
||||
actor=default_user,
|
||||
visible_content=f"content for {file.file_name}",
|
||||
)
|
||||
@@ -7617,7 +7646,12 @@ async def test_lru_no_eviction_when_reopening_same_file(server, default_user, sa
|
||||
|
||||
# "Reopen" the last file (which is already open)
|
||||
closed_files, was_already_open = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
agent_id=sarah_agent.id, file_id=files[-1].id, file_name=files[-1].file_name, actor=default_user, visible_content="updated content"
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=files[-1].id,
|
||||
file_name=files[-1].file_name,
|
||||
source_id=files[-1].source_id,
|
||||
actor=default_user,
|
||||
visible_content="updated content",
|
||||
)
|
||||
|
||||
# Should not have closed any files since we're within the limit
|
||||
@@ -7645,7 +7679,12 @@ async def test_last_accessed_at_updates_correctly(server, default_user, sarah_ag
|
||||
file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user, text="test content")
|
||||
|
||||
file_agent, closed_files = await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id, file_id=file.id, file_name=file.file_name, actor=default_user, visible_content="initial content"
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
source_id=file.source_id,
|
||||
actor=default_user,
|
||||
visible_content="initial content",
|
||||
)
|
||||
|
||||
initial_time = file_agent.last_accessed_at
|
||||
@@ -7777,6 +7816,7 @@ async def test_attach_files_bulk_lru_eviction(server, default_user, sarah_agent,
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
source_id=file.source_id,
|
||||
actor=default_user,
|
||||
visible_content=f"existing content {i}",
|
||||
)
|
||||
@@ -7842,6 +7882,7 @@ async def test_attach_files_bulk_mixed_existing_new(server, default_user, sarah_
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=existing_file.id,
|
||||
file_name=existing_file.file_name,
|
||||
source_id=existing_file.source_id,
|
||||
actor=default_user,
|
||||
visible_content="old content",
|
||||
is_open=False, # Start as closed
|
||||
|
||||
@@ -1033,3 +1033,36 @@ def test_preview_payload(client: LettaSDKClient, agent):
|
||||
assert payload["user"].startswith("user-")
|
||||
|
||||
print(payload)
|
||||
|
||||
|
||||
def test_agent_tools_list(client: LettaSDKClient):
|
||||
"""Test the optimized agent tools list endpoint for correctness."""
|
||||
# Create a test agent
|
||||
agent_state = client.agents.create(
|
||||
name="test_agent_tools_list",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You are a helpful assistant.",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
include_base_tools=True,
|
||||
)
|
||||
|
||||
try:
|
||||
# Test basic functionality
|
||||
tools = client.agents.tools.list(agent_id=agent_state.id)
|
||||
assert len(tools) > 0, "Agent should have base tools attached"
|
||||
|
||||
# Verify tool objects have expected attributes
|
||||
for tool in tools:
|
||||
assert hasattr(tool, "id"), "Tool should have id attribute"
|
||||
assert hasattr(tool, "name"), "Tool should have name attribute"
|
||||
assert tool.id is not None, "Tool id should not be None"
|
||||
assert tool.name is not None, "Tool name should not be None"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
@@ -464,9 +464,8 @@ def test_line_chunker_edge_case_empty_file():
|
||||
file = FileMetadata(file_name="empty.py", source_id="test_source", content="")
|
||||
chunker = LineChunker()
|
||||
|
||||
# Test requesting lines from empty file
|
||||
with pytest.raises(ValueError, match="File empty.py has only 0 lines, but requested offset 1 is out of range"):
|
||||
chunker.chunk_text(file, start=0, end=1, validate_range=True)
|
||||
# no error
|
||||
chunker.chunk_text(file, start=0, end=1, validate_range=True)
|
||||
|
||||
|
||||
def test_line_chunker_edge_case_single_line():
|
||||
|
||||
Reference in New Issue
Block a user