feat: Insert file blocks for agent on source attach (#2545)
This commit is contained in:
@@ -12,7 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import CORE_MEMORY_SOURCE_CHAR_LIMIT, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.log import get_logger
|
||||
@@ -36,6 +36,7 @@ from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
||||
|
||||
@@ -301,7 +302,6 @@ async def detach_tool(
|
||||
async def attach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
@@ -309,11 +309,30 @@ async def attach_source(
|
||||
Attach a source to an agent.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
agent = await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
if agent.enable_sleeptime:
|
||||
agent_state = await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
|
||||
files = await server.source_manager.list_files(source_id, actor)
|
||||
texts = []
|
||||
filenames = []
|
||||
for f in files:
|
||||
passages = await server.passage_manager.list_passages_by_file_id_async(file_id=f.id, actor=actor)
|
||||
passage_text = ""
|
||||
for p in passages:
|
||||
if len(passage_text) <= CORE_MEMORY_SOURCE_CHAR_LIMIT:
|
||||
passage_text += p.text
|
||||
|
||||
texts.append(passage_text)
|
||||
filenames.append(f.file_name)
|
||||
|
||||
await server.insert_documents_into_context_window(agent_state=agent_state, texts=texts, filenames=filenames, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id)
|
||||
background_tasks.add_task(server.sleeptime_document_ingest_async, agent, source, actor)
|
||||
return agent
|
||||
safe_create_task(
|
||||
server.sleeptime_document_ingest_async(agent_state, source, actor), logger=logger, label="sleeptime_document_ingest_async"
|
||||
)
|
||||
|
||||
return agent_state
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/sources/detach/{source_id}", response_model=AgentState, operation_id="detach_source_from_agent")
|
||||
|
||||
@@ -168,6 +168,11 @@ async def upload_file_to_source(
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
|
||||
bytes = file.file.read()
|
||||
|
||||
try:
|
||||
text = bytes.decode("utf-8")
|
||||
except Exception:
|
||||
text = ""
|
||||
|
||||
# create job
|
||||
job = Job(
|
||||
user_id=actor.id,
|
||||
@@ -180,17 +185,15 @@ async def upload_file_to_source(
|
||||
# sanitize filename
|
||||
sanitized_filename = sanitize_filename(file.filename)
|
||||
|
||||
# Add blocks
|
||||
await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=sanitized_filename, actor=actor)
|
||||
|
||||
# create background tasks
|
||||
safe_create_task(
|
||||
load_file_to_source_async(server, source_id=source.id, filename=sanitized_filename, job_id=job.id, bytes=bytes, actor=actor),
|
||||
logger=logger,
|
||||
label="load_file_to_source_async",
|
||||
)
|
||||
safe_create_task(
|
||||
insert_document_into_context_window_async(server, filename=sanitized_filename, source_id=source_id, actor=actor, bytes=bytes),
|
||||
logger=logger,
|
||||
label="insert_document_into_context_window_async",
|
||||
)
|
||||
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async")
|
||||
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
|
||||
@@ -278,8 +281,3 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac
|
||||
for agent in agents:
|
||||
if agent.enable_sleeptime:
|
||||
await server.sleeptime_document_ingest_async(agent, source, actor, clear_history)
|
||||
|
||||
|
||||
async def insert_document_into_context_window_async(server: SyncServer, filename: str, source_id: str, actor: User, bytes: bytes):
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
await server.insert_document_into_context_window(source, bytes=bytes, filename=filename, actor=actor)
|
||||
|
||||
@@ -1363,46 +1363,71 @@ class SyncServer(Server):
|
||||
)
|
||||
await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor)
|
||||
|
||||
async def insert_document_into_context_window(self, source: Source, bytes: bytes, filename: str, actor: User) -> None:
|
||||
async def _upsert_document_block(self, agent_id: str, text: str, filename: str, actor: User) -> None:
|
||||
"""
|
||||
Internal method to create or update a document block for an agent.
|
||||
This is the shared logic between single and multiple document insertion.
|
||||
"""
|
||||
truncated_text = text[:CORE_MEMORY_SOURCE_CHAR_LIMIT]
|
||||
|
||||
try:
|
||||
block = await self.agent_manager.get_block_with_label_async(
|
||||
agent_id=agent_id,
|
||||
block_label=filename,
|
||||
actor=actor,
|
||||
)
|
||||
await self.block_manager.update_block_async(
|
||||
block_id=block.id,
|
||||
block_update=BlockUpdate(value=truncated_text),
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
block = await self.block_manager.create_or_update_block_async(
|
||||
block=Block(
|
||||
value=truncated_text,
|
||||
label=filename,
|
||||
description=f"Contains the parsed contents of external file {filename}",
|
||||
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
await self.agent_manager.attach_block_async(
|
||||
agent_id=agent_id,
|
||||
block_id=block.id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
async def insert_document_into_context_windows(self, source_id: str, text: str, filename: str, actor: User) -> None:
|
||||
"""
|
||||
Insert the uploaded document into the context window of all agents
|
||||
attached to the given source.
|
||||
"""
|
||||
agent_states = await self.source_manager.list_attached_agents(source_id=source.id, actor=actor)
|
||||
logger.info(f"Inserting document into context window for source: {source}")
|
||||
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
|
||||
# Return early
|
||||
if not agent_states:
|
||||
return
|
||||
|
||||
logger.info(f"Inserting document into context window for source: {source_id}")
|
||||
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
|
||||
|
||||
passages = bytes.decode("utf-8")[:CORE_MEMORY_SOURCE_CHAR_LIMIT]
|
||||
await asyncio.gather(*(self._upsert_document_block(agent_state.id, text, filename, actor) for agent_state in agent_states))
|
||||
|
||||
async def process_agent(agent_state):
|
||||
try:
|
||||
block = await self.agent_manager.get_block_with_label_async(
|
||||
agent_id=agent_state.id,
|
||||
block_label=filename,
|
||||
actor=actor,
|
||||
)
|
||||
await self.block_manager.update_block_async(
|
||||
block_id=block.id,
|
||||
block_update=BlockUpdate(value=passages),
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
block = await self.block_manager.create_or_update_block_async(
|
||||
block=Block(
|
||||
value=passages,
|
||||
label=filename,
|
||||
description="Contains recursive summarizations of the conversation so far",
|
||||
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
await self.agent_manager.attach_block_async(
|
||||
agent_id=agent_state.id,
|
||||
block_id=block.id,
|
||||
actor=actor,
|
||||
)
|
||||
async def insert_documents_into_context_window(
|
||||
self, agent_state: AgentState, texts: List[str], filenames: List[str], actor: User
|
||||
) -> None:
|
||||
"""
|
||||
Insert the uploaded documents into the context window of an agent
|
||||
attached to the given source.
|
||||
"""
|
||||
logger.info(f"Inserting documents into context window for agent_state: {agent_state.id}")
|
||||
|
||||
await asyncio.gather(*(process_agent(agent) for agent in agent_states))
|
||||
if len(texts) != len(filenames):
|
||||
raise ValueError(f"Mismatch between number of texts ({len(texts)}) and filenames ({len(filenames)})")
|
||||
|
||||
await asyncio.gather(
|
||||
*(self._upsert_document_block(agent_state.id, text, filename, actor) for text, filename in zip(texts, filenames))
|
||||
)
|
||||
|
||||
async def create_document_sleeptime_agent_async(
|
||||
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List, Optional
|
||||
|
||||
from async_lru import alru_cache
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from sqlalchemy import select
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import embedding_model, parse_and_chunk_text
|
||||
@@ -448,3 +449,16 @@ class PassageManager:
|
||||
BYTES_PER_EMBEDDING_DIM = 4
|
||||
GB_PER_EMBEDDING = BYTES_PER_EMBEDDING_DIM / BYTES_PER_STORAGE_UNIT[storage_unit] * MAX_EMBEDDING_DIM
|
||||
return await self.size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_passages_by_file_id_async(self, file_id: str, actor: PydanticUser) -> List[PydanticPassage]:
|
||||
"""
|
||||
List all source passages associated with a given file_id.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SourcePassage).where(SourcePassage.file_id == file_id).where(SourcePassage.organization_id == actor.organization_id)
|
||||
)
|
||||
passages = result.scalars().all()
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@@ -38,7 +38,7 @@ def client() -> LettaSDKClient:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture
|
||||
def agent_state(client: LettaSDKClient):
|
||||
agent_state = client.agents.create(
|
||||
memory_blocks=[
|
||||
@@ -90,7 +90,48 @@ def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, age
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Get the agent state
|
||||
# Get the agent state, check blocks exist
|
||||
blocks = client.agents.blocks.list(agent_id=agent_state.id)
|
||||
assert len(blocks) == 2
|
||||
assert "test" in [b.value for b in blocks]
|
||||
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
|
||||
|
||||
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Clear existing sources
|
||||
for source in client.sources.list():
|
||||
client.sources.delete(source_id=source.id)
|
||||
|
||||
# Clear existing jobs
|
||||
for job in client.jobs.list():
|
||||
client.jobs.delete(job_id=job.id)
|
||||
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
|
||||
assert len(client.sources.list()) == 1
|
||||
|
||||
# Load files into the source
|
||||
file_path = "tests/data/test.txt"
|
||||
|
||||
# Upload the files
|
||||
with open(file_path, "rb") as f:
|
||||
job = client.sources.files.upload(source_id=source.id, file=f)
|
||||
|
||||
# Wait for the jobs to complete
|
||||
while job.status != "completed":
|
||||
time.sleep(1)
|
||||
job = client.jobs.retrieve(job_id=job.id)
|
||||
print("Waiting for jobs to complete...", job.status)
|
||||
|
||||
# Get the first file with pagination
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Attach after uploading the file
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Get the agent state, check blocks exist
|
||||
blocks = client.agents.blocks.list(agent_id=agent_state.id)
|
||||
assert len(blocks) == 2
|
||||
assert "test" in [b.value for b in blocks]
|
||||
|
||||
Reference in New Issue
Block a user