feat: Add files into agent context window on file upload (#1852)

Co-authored-by: Caren Thomas <carenthomas@gmail.com>
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Kevin Lin
2025-05-29 18:19:23 -07:00
committed by GitHub
parent f54c62eacc
commit ed4b28f3e4
14 changed files with 245 additions and 82 deletions

View File

@@ -231,6 +231,7 @@ class Agent(BaseAgent):
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
return True
return False
def _handle_function_error_response(

View File

@@ -290,7 +290,7 @@ MAX_ERROR_MESSAGE_CHAR_LIMIT = 500
CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 5000
CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 5000
CORE_MEMORY_BLOCK_CHAR_LIMIT: int = 5000
CORE_MEMORY_SOURCE_CHAR_LIMIT: int = 5000
# Function return limits
FUNCTION_RETURN_CHAR_LIMIT = 6000 # ~300 words
BASE_FUNCTION_RETURN_CHAR_LIMIT = 1000000 # very high (we rely on implementation)

View File

@@ -4,8 +4,10 @@ import tempfile
from typing import List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, UploadFile
from starlette import status
import letta.constants as constants
from letta.log import get_logger
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.passage import Passage
@@ -13,9 +15,9 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.schemas.user import User
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.utils import sanitize_filename
from letta.utils import safe_create_task, sanitize_filename
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
logger = get_logger(__name__)
router = APIRouter(prefix="/sources", tags=["sources"])
@@ -153,7 +155,6 @@ async def delete_source(
async def upload_file_to_source(
file: UploadFile,
source_id: str,
background_tasks: BackgroundTasks,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
@@ -163,7 +164,8 @@ async def upload_file_to_source(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
assert source is not None, f"Source with id={source_id} not found."
if source is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
bytes = file.file.read()
# create job
@@ -175,14 +177,25 @@ async def upload_file_to_source(
job_id = job.id
await server.job_manager.create_job_async(job, actor=actor)
# create background tasks
asyncio.create_task(load_file_to_source_async(server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor))
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor))
# sanitize filename
sanitized_filename = sanitize_filename(file.filename)
# create background tasks
safe_create_task(
load_file_to_source_async(server, source_id=source.id, filename=sanitized_filename, job_id=job.id, bytes=bytes, actor=actor),
logger=logger,
label="load_file_to_source_async",
)
safe_create_task(
insert_document_into_context_window_async(server, filename=sanitized_filename, source_id=source_id, actor=actor, bytes=bytes),
logger=logger,
label="insert_document_into_context_window_async",
)
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async")
# return job information
# Is this necessary? Can we just return the job from create_job?
job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
assert job is not None, "Job not found"
if job is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Job with id={job_id} not found.")
return job
@@ -246,12 +259,10 @@ async def delete_file_from_source(
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User):
async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, filename: str, bytes: bytes, actor: User):
# Create a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
# Sanitize the filename
sanitized_filename = sanitize_filename(file.filename)
file_path = os.path.join(tmpdirname, sanitized_filename)
file_path = os.path.join(tmpdirname, filename)
# Write the file to the sanitized path
with open(file_path, "wb") as buffer:
@@ -267,3 +278,8 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac
for agent in agents:
if agent.enable_sleeptime:
await server.sleeptime_document_ingest_async(agent, source, actor, clear_history)
async def insert_document_into_context_window_async(server: SyncServer, filename: str, source_id: str, actor: User, bytes: bytes):
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
await server.insert_document_into_context_window(source, bytes=bytes, filename=filename, actor=actor)

View File

@@ -21,7 +21,7 @@ import letta.system as system
from letta.agent import Agent, save_agent
from letta.agents.letta_agent import LettaAgent
from letta.config import LettaConfig
from letta.constants import LETTA_TOOL_EXECUTION_DIR
from letta.constants import CORE_MEMORY_SOURCE_CHAR_LIMIT, LETTA_TOOL_EXECUTION_DIR
from letta.data_sources.connectors import DataConnector, load_data
from letta.errors import HandleNotFoundError
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
@@ -1363,6 +1363,47 @@ class SyncServer(Server):
)
await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor)
async def insert_document_into_context_window(self, source: Source, bytes: bytes, filename: str, actor: User) -> None:
"""
Insert the uploaded document into the context window of all agents
attached to the given source.
"""
agent_states = await self.source_manager.list_attached_agents(source_id=source.id, actor=actor)
logger.info(f"Inserting document into context window for source: {source}")
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
passages = bytes.decode("utf-8")[:CORE_MEMORY_SOURCE_CHAR_LIMIT]
async def process_agent(agent_state):
try:
block = await self.agent_manager.get_block_with_label_async(
agent_id=agent_state.id,
block_label=filename,
actor=actor,
)
await self.block_manager.update_block_async(
block_id=block.id,
block_update=BlockUpdate(value=passages),
actor=actor,
)
except NoResultFound:
block = await self.block_manager.create_or_update_block_async(
block=Block(
value=passages,
label=filename,
description="Contains recursive summarizations of the conversation so far",
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
),
actor=actor,
)
await self.agent_manager.attach_block_async(
agent_id=agent_state.id,
block_id=block.id,
actor=actor,
)
await asyncio.gather(*(process_agent(agent) for agent in agent_states))
async def create_document_sleeptime_agent_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
) -> AgentState:

View File

@@ -173,7 +173,6 @@ class SourceManager:
) -> List[PydanticFileMetadata]:
"""List all files with optional pagination."""
async with db_registry.async_session() as session:
files_all = await FileMetadataModel.list_async(db_session=session, organization_id=actor.organization_id, source_id=source_id)
files = await FileMetadataModel.list_async(
db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id
)

View File

@@ -1018,7 +1018,7 @@ def sanitize_filename(filename: str) -> str:
base = base[:max_base_length]
# Append a unique UUID suffix for uniqueness
unique_suffix = uuid.uuid4().hex
unique_suffix = uuid.uuid4().hex[:4]
sanitized_filename = f"{base}_{unique_suffix}{ext}"
# Return the sanitized filename
@@ -1088,3 +1088,13 @@ def log_telemetry(logger: Logger, event: str, **kwargs):
def make_key(*args, **kwargs):
return str((args, tuple(sorted(kwargs.items()))))
def safe_create_task(coro, logger: Logger, label: str = "background task"):
async def wrapper():
try:
await coro
except Exception as e:
logger.exception(f"{label} failed with {type(e).__name__}: {e}")
return asyncio.create_task(wrapper())

View File

@@ -1,7 +1,7 @@
{
"context_window": 16000,
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
"model_endpoint_type": "together",
"model_endpoint": "https://api.together.ai/v1",
"model_wrapper": "chatml"
"context_window": 16000,
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
"model_endpoint_type": "together",
"model_endpoint": "https://api.together.ai/v1",
"model_wrapper": "chatml"
}

View File

@@ -1 +1 @@
{}
{}

View File

@@ -420,7 +420,7 @@ def test_load_file(client: RESTClient, agent: AgentState):
# Get the memgpt paper
file = files[0]
# Assert the filename matches the pattern
pattern = re.compile(r"^memgpt_paper_[a-f0-9]{32}\.pdf$")
pattern = re.compile(r"^memgpt_paper_[a-f0-9]+\.pdf$")
assert pattern.match(file.file_name), f"Filename '{file.file_name}' does not match expected pattern."
assert file.source_id == source.id

View File

@@ -113,9 +113,8 @@ def test_shared_blocks(client: LettaSDKClient):
)
],
)
assert (
"charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower()
), f"Shared block update failed {client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value}"
block_value = client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value
assert "charles" in block_value.lower(), f"Shared block update failed {block_value}"
# cleanup
client.agents.delete(agent_state1.id)
@@ -682,7 +681,7 @@ def test_many_blocks(client: LettaSDKClient):
client.agents.delete(agent2.id)
def test_sources(client: LettaSDKClient, agent: AgentState):
def test_sources_crud(client: LettaSDKClient, agent: AgentState):
# Clear existing sources
for source in client.sources.list():

97
tests/test_sources.py Normal file
View File

@@ -0,0 +1,97 @@
import os
import re
import threading
import time
import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client.types import AgentState
from tests.utils import wait_for_server
# Constants
SERVER_PORT = 8283
def run_server():
load_dotenv()
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(scope="module")
def client() -> LettaSDKClient:
# Get URL from environment or start server
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
if not os.getenv("LETTA_SERVER_URL"):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
wait_for_server(server_url)
print("Running client tests with server:", server_url)
client = LettaSDKClient(base_url=server_url, token=None)
yield client
@pytest.fixture(scope="module")
def agent_state(client: LettaSDKClient):
agent_state = client.agents.create(
memory_blocks=[
CreateBlock(
label="human",
value="username: sarah",
),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
)
yield agent_state
# delete agent
client.agents.delete(agent_id=agent_state.id)
def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
# Clear existing sources
for source in client.sources.list():
client.sources.delete(source_id=source.id)
# Clear existing jobs
for job in client.jobs.list():
client.jobs.delete(job_id=job.id)
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
assert len(client.sources.list()) == 1
# Attach
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
# Load files into the source
file_path = "tests/data/test.txt"
# Upload the files
with open(file_path, "rb") as f:
job = client.sources.files.upload(source_id=source.id, file=f)
# Wait for the jobs to complete
while job.status != "completed":
time.sleep(1)
job = client.jobs.retrieve(job_id=job.id)
print("Waiting for jobs to complete...", job.status)
# Get the first file with pagination
files = client.sources.files.list(source_id=source.id, limit=1)
assert len(files) == 1
assert files[0].source_id == source.id
# Get the agent state
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 2
assert "test" in [b.value for b in blocks]
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)

View File

@@ -9,21 +9,21 @@
"description": "List of steps to add to the task plan.",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name of the step."
},
"key": {
"type": "string",
"description": "Unique identifier for the step."
},
"description": {
"type": "string",
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
}
"properties": {
"name": {
"type": "string",
"description": "Name of the step."
},
"required": ["name", "key", "description"]
"key": {
"type": "string",
"description": "Unique identifier for the step."
},
"description": {
"type": "string",
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
}
},
"required": ["name", "key", "description"]
}
}
},

View File

@@ -1,35 +1,35 @@
{
"name": "create_task_plan",
"description": "It takes in a list of steps, and updates the task with the new steps provided.\nIf there are any current steps, they will be overwritten.\nEach step in the list should have the following format:\n{\n \"name\": <string> -- Name of the step.\n \"key\": <string> -- Unique identifier for the step.\n \"description\": <string> -- An exhaustic description of what this step is trying to achieve and accomplish.\n}",
"strict": true,
"parameters": {
"type": "object",
"properties": {
"steps": {
"type": "array",
"description": "List of steps to add to the task plan.",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name of the step."
},
"key": {
"type": "string",
"description": "Unique identifier for the step."
},
"description": {
"type": "string",
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
}
},
"additionalProperties": false,
"required": ["name", "key", "description"]
}
"name": "create_task_plan",
"description": "It takes in a list of steps, and updates the task with the new steps provided.\nIf there are any current steps, they will be overwritten.\nEach step in the list should have the following format:\n{\n \"name\": <string> -- Name of the step.\n \"key\": <string> -- Unique identifier for the step.\n \"description\": <string> -- An exhaustic description of what this step is trying to achieve and accomplish.\n}",
"strict": true,
"parameters": {
"type": "object",
"properties": {
"steps": {
"type": "array",
"description": "List of steps to add to the task plan.",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name of the step."
},
"key": {
"type": "string",
"description": "Unique identifier for the step."
},
"description": {
"type": "string",
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
}
},
"additionalProperties": false,
"required": ["name", "key", "description"]
}
},
"additionalProperties": false,
"required": ["steps"]
}
}
},
"additionalProperties": false,
"required": ["steps"]
}
}

View File

@@ -1,9 +1,9 @@
{
"name": "roll_d20",
"description": "This function generates a random integer between 1 and 20, inclusive,\nwhich represents the outcome of a single roll of a d20.",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
"name": "roll_d20",
"description": "This function generates a random integer between 1 and 20, inclusive,\nwhich represents the outcome of a single roll of a d20.",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}