feat: Add files into agent context window on file upload (#1852)
Co-authored-by: Caren Thomas <carenthomas@gmail.com> Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -231,6 +231,7 @@ class Agent(BaseAgent):
|
||||
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _handle_function_error_response(
|
||||
|
||||
@@ -290,7 +290,7 @@ MAX_ERROR_MESSAGE_CHAR_LIMIT = 500
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 5000
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 5000
|
||||
CORE_MEMORY_BLOCK_CHAR_LIMIT: int = 5000
|
||||
|
||||
CORE_MEMORY_SOURCE_CHAR_LIMIT: int = 5000
|
||||
# Function return limits
|
||||
FUNCTION_RETURN_CHAR_LIMIT = 6000 # ~300 words
|
||||
BASE_FUNCTION_RETURN_CHAR_LIMIT = 1000000 # very high (we rely on implementation)
|
||||
|
||||
@@ -4,8 +4,10 @@ import tempfile
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, UploadFile
|
||||
from starlette import status
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.passage import Passage
|
||||
@@ -13,9 +15,9 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.utils import sanitize_filename
|
||||
from letta.utils import safe_create_task, sanitize_filename
|
||||
|
||||
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/sources", tags=["sources"])
|
||||
@@ -153,7 +155,6 @@ async def delete_source(
|
||||
async def upload_file_to_source(
|
||||
file: UploadFile,
|
||||
source_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
@@ -163,7 +164,8 @@ async def upload_file_to_source(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
if source is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
|
||||
bytes = file.file.read()
|
||||
|
||||
# create job
|
||||
@@ -175,14 +177,25 @@ async def upload_file_to_source(
|
||||
job_id = job.id
|
||||
await server.job_manager.create_job_async(job, actor=actor)
|
||||
|
||||
# create background tasks
|
||||
asyncio.create_task(load_file_to_source_async(server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor))
|
||||
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor))
|
||||
# sanitize filename
|
||||
sanitized_filename = sanitize_filename(file.filename)
|
||||
|
||||
# create background tasks
|
||||
safe_create_task(
|
||||
load_file_to_source_async(server, source_id=source.id, filename=sanitized_filename, job_id=job.id, bytes=bytes, actor=actor),
|
||||
logger=logger,
|
||||
label="load_file_to_source_async",
|
||||
)
|
||||
safe_create_task(
|
||||
insert_document_into_context_window_async(server, filename=sanitized_filename, source_id=source_id, actor=actor, bytes=bytes),
|
||||
logger=logger,
|
||||
label="insert_document_into_context_window_async",
|
||||
)
|
||||
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async")
|
||||
|
||||
# return job information
|
||||
# Is this necessary? Can we just return the job from create_job?
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
|
||||
assert job is not None, "Job not found"
|
||||
if job is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Job with id={job_id} not found.")
|
||||
return job
|
||||
|
||||
|
||||
@@ -246,12 +259,10 @@ async def delete_file_from_source(
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
|
||||
|
||||
|
||||
async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User):
|
||||
async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, filename: str, bytes: bytes, actor: User):
|
||||
# Create a temporary directory (deleted after the context manager exits)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# Sanitize the filename
|
||||
sanitized_filename = sanitize_filename(file.filename)
|
||||
file_path = os.path.join(tmpdirname, sanitized_filename)
|
||||
file_path = os.path.join(tmpdirname, filename)
|
||||
|
||||
# Write the file to the sanitized path
|
||||
with open(file_path, "wb") as buffer:
|
||||
@@ -267,3 +278,8 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac
|
||||
for agent in agents:
|
||||
if agent.enable_sleeptime:
|
||||
await server.sleeptime_document_ingest_async(agent, source, actor, clear_history)
|
||||
|
||||
|
||||
async def insert_document_into_context_window_async(server: SyncServer, filename: str, source_id: str, actor: User, bytes: bytes):
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
await server.insert_document_into_context_window(source, bytes=bytes, filename=filename, actor=actor)
|
||||
|
||||
@@ -21,7 +21,7 @@ import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.constants import CORE_MEMORY_SOURCE_CHAR_LIMIT, LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.errors import HandleNotFoundError
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
|
||||
@@ -1363,6 +1363,47 @@ class SyncServer(Server):
|
||||
)
|
||||
await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor)
|
||||
|
||||
async def insert_document_into_context_window(self, source: Source, bytes: bytes, filename: str, actor: User) -> None:
|
||||
"""
|
||||
Insert the uploaded document into the context window of all agents
|
||||
attached to the given source.
|
||||
"""
|
||||
agent_states = await self.source_manager.list_attached_agents(source_id=source.id, actor=actor)
|
||||
logger.info(f"Inserting document into context window for source: {source}")
|
||||
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
|
||||
|
||||
passages = bytes.decode("utf-8")[:CORE_MEMORY_SOURCE_CHAR_LIMIT]
|
||||
|
||||
async def process_agent(agent_state):
|
||||
try:
|
||||
block = await self.agent_manager.get_block_with_label_async(
|
||||
agent_id=agent_state.id,
|
||||
block_label=filename,
|
||||
actor=actor,
|
||||
)
|
||||
await self.block_manager.update_block_async(
|
||||
block_id=block.id,
|
||||
block_update=BlockUpdate(value=passages),
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
block = await self.block_manager.create_or_update_block_async(
|
||||
block=Block(
|
||||
value=passages,
|
||||
label=filename,
|
||||
description="Contains recursive summarizations of the conversation so far",
|
||||
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
await self.agent_manager.attach_block_async(
|
||||
agent_id=agent_state.id,
|
||||
block_id=block.id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
await asyncio.gather(*(process_agent(agent) for agent in agent_states))
|
||||
|
||||
async def create_document_sleeptime_agent_async(
|
||||
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
|
||||
) -> AgentState:
|
||||
|
||||
@@ -173,7 +173,6 @@ class SourceManager:
|
||||
) -> List[PydanticFileMetadata]:
|
||||
"""List all files with optional pagination."""
|
||||
async with db_registry.async_session() as session:
|
||||
files_all = await FileMetadataModel.list_async(db_session=session, organization_id=actor.organization_id, source_id=source_id)
|
||||
files = await FileMetadataModel.list_async(
|
||||
db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id
|
||||
)
|
||||
|
||||
@@ -1018,7 +1018,7 @@ def sanitize_filename(filename: str) -> str:
|
||||
base = base[:max_base_length]
|
||||
|
||||
# Append a unique UUID suffix for uniqueness
|
||||
unique_suffix = uuid.uuid4().hex
|
||||
unique_suffix = uuid.uuid4().hex[:4]
|
||||
sanitized_filename = f"{base}_{unique_suffix}{ext}"
|
||||
|
||||
# Return the sanitized filename
|
||||
@@ -1088,3 +1088,13 @@ def log_telemetry(logger: Logger, event: str, **kwargs):
|
||||
|
||||
def make_key(*args, **kwargs):
|
||||
return str((args, tuple(sorted(kwargs.items()))))
|
||||
|
||||
|
||||
def safe_create_task(coro, logger: Logger, label: str = "background task"):
|
||||
async def wrapper():
|
||||
try:
|
||||
await coro
|
||||
except Exception as e:
|
||||
logger.exception(f"{label} failed with {type(e).__name__}: {e}")
|
||||
|
||||
return asyncio.create_task(wrapper())
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"context_window": 16000,
|
||||
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
"model_endpoint_type": "together",
|
||||
"model_endpoint": "https://api.together.ai/v1",
|
||||
"model_wrapper": "chatml"
|
||||
"context_window": 16000,
|
||||
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
"model_endpoint_type": "together",
|
||||
"model_endpoint": "https://api.together.ai/v1",
|
||||
"model_wrapper": "chatml"
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
{}
|
||||
{}
|
||||
|
||||
@@ -420,7 +420,7 @@ def test_load_file(client: RESTClient, agent: AgentState):
|
||||
# Get the memgpt paper
|
||||
file = files[0]
|
||||
# Assert the filename matches the pattern
|
||||
pattern = re.compile(r"^memgpt_paper_[a-f0-9]{32}\.pdf$")
|
||||
pattern = re.compile(r"^memgpt_paper_[a-f0-9]+\.pdf$")
|
||||
assert pattern.match(file.file_name), f"Filename '{file.file_name}' does not match expected pattern."
|
||||
|
||||
assert file.source_id == source.id
|
||||
|
||||
@@ -113,9 +113,8 @@ def test_shared_blocks(client: LettaSDKClient):
|
||||
)
|
||||
],
|
||||
)
|
||||
assert (
|
||||
"charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower()
|
||||
), f"Shared block update failed {client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value}"
|
||||
block_value = client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value
|
||||
assert "charles" in block_value.lower(), f"Shared block update failed {block_value}"
|
||||
|
||||
# cleanup
|
||||
client.agents.delete(agent_state1.id)
|
||||
@@ -682,7 +681,7 @@ def test_many_blocks(client: LettaSDKClient):
|
||||
client.agents.delete(agent2.id)
|
||||
|
||||
|
||||
def test_sources(client: LettaSDKClient, agent: AgentState):
|
||||
def test_sources_crud(client: LettaSDKClient, agent: AgentState):
|
||||
|
||||
# Clear existing sources
|
||||
for source in client.sources.list():
|
||||
|
||||
97
tests/test_sources.py
Normal file
97
tests/test_sources.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client.types import AgentState
|
||||
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
|
||||
|
||||
def run_server():
|
||||
load_dotenv()
|
||||
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
print("Starting server...")
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client() -> LettaSDKClient:
|
||||
# Get URL from environment or start server
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(server_url)
|
||||
print("Running client tests with server:", server_url)
|
||||
client = LettaSDKClient(base_url=server_url, token=None)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_state(client: LettaSDKClient):
|
||||
agent_state = client.agents.create(
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="human",
|
||||
value="username: sarah",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Clear existing sources
|
||||
for source in client.sources.list():
|
||||
client.sources.delete(source_id=source.id)
|
||||
|
||||
# Clear existing jobs
|
||||
for job in client.jobs.list():
|
||||
client.jobs.delete(job_id=job.id)
|
||||
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
|
||||
assert len(client.sources.list()) == 1
|
||||
|
||||
# Attach
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Load files into the source
|
||||
file_path = "tests/data/test.txt"
|
||||
|
||||
# Upload the files
|
||||
with open(file_path, "rb") as f:
|
||||
job = client.sources.files.upload(source_id=source.id, file=f)
|
||||
|
||||
# Wait for the jobs to complete
|
||||
while job.status != "completed":
|
||||
time.sleep(1)
|
||||
job = client.jobs.retrieve(job_id=job.id)
|
||||
print("Waiting for jobs to complete...", job.status)
|
||||
|
||||
# Get the first file with pagination
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Get the agent state
|
||||
blocks = client.agents.blocks.list(agent_id=agent_state.id)
|
||||
assert len(blocks) == 2
|
||||
assert "test" in [b.value for b in blocks]
|
||||
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
|
||||
@@ -9,21 +9,21 @@
|
||||
"description": "List of steps to add to the task plan.",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name of the step."
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the step."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
|
||||
}
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name of the step."
|
||||
},
|
||||
"required": ["name", "key", "description"]
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the step."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
|
||||
}
|
||||
},
|
||||
"required": ["name", "key", "description"]
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -1,35 +1,35 @@
|
||||
{
|
||||
"name": "create_task_plan",
|
||||
"description": "It takes in a list of steps, and updates the task with the new steps provided.\nIf there are any current steps, they will be overwritten.\nEach step in the list should have the following format:\n{\n \"name\": <string> -- Name of the step.\n \"key\": <string> -- Unique identifier for the step.\n \"description\": <string> -- An exhaustic description of what this step is trying to achieve and accomplish.\n}",
|
||||
"strict": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"steps": {
|
||||
"type": "array",
|
||||
"description": "List of steps to add to the task plan.",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name of the step."
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the step."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["name", "key", "description"]
|
||||
}
|
||||
"name": "create_task_plan",
|
||||
"description": "It takes in a list of steps, and updates the task with the new steps provided.\nIf there are any current steps, they will be overwritten.\nEach step in the list should have the following format:\n{\n \"name\": <string> -- Name of the step.\n \"key\": <string> -- Unique identifier for the step.\n \"description\": <string> -- An exhaustic description of what this step is trying to achieve and accomplish.\n}",
|
||||
"strict": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"steps": {
|
||||
"type": "array",
|
||||
"description": "List of steps to add to the task plan.",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name of the step."
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the step."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["name", "key", "description"]
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["steps"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["steps"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
{
|
||||
"name": "roll_d20",
|
||||
"description": "This function generates a random integer between 1 and 20, inclusive,\nwhich represents the outcome of a single roll of a d20.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
"name": "roll_d20",
|
||||
"description": "This function generates a random integer between 1 and 20, inclusive,\nwhich represents the outcome of a single roll of a d20.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user