feat(asyncify): migrate upload and attach source (#2432)
This commit is contained in:
@@ -98,14 +98,14 @@ async def load_data(
|
||||
embedding_to_document_name[hashable_embedding] = file_name
|
||||
if len(passages) >= 100:
|
||||
# insert passages into passage store
|
||||
passage_manager.create_many_passages(passages, actor)
|
||||
await passage_manager.create_many_passages_async(passages, actor)
|
||||
|
||||
passage_count += len(passages)
|
||||
passages = []
|
||||
|
||||
if len(passages) > 0:
|
||||
# insert passages into passage store
|
||||
passage_manager.create_many_passages(passages, actor)
|
||||
await passage_manager.create_many_passages_async(passages, actor)
|
||||
passage_count += len(passages)
|
||||
|
||||
return passage_count, file_count
|
||||
|
||||
@@ -298,7 +298,7 @@ async def detach_tool(
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")
|
||||
def attach_source(
|
||||
async def attach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
@@ -308,8 +308,8 @@ def attach_source(
|
||||
"""
|
||||
Attach a source to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
agent = server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
agent = await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
if agent.enable_sleeptime:
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id)
|
||||
background_tasks.add_task(server.sleeptime_document_ingest, agent, source, actor)
|
||||
@@ -317,7 +317,7 @@ def attach_source(
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/sources/detach/{source_id}", response_model=AgentState, operation_id="detach_source_from_agent")
|
||||
def detach_source(
|
||||
async def detach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -326,8 +326,8 @@ def detach_source(
|
||||
"""
|
||||
Detach a source from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
agent = server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
agent = await server.agent_manager.detach_source_async(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
if agent.enable_sleeptime:
|
||||
try:
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id)
|
||||
|
||||
@@ -174,7 +174,7 @@ async def upload_file_to_source(
|
||||
completed_at=None,
|
||||
)
|
||||
job_id = job.id
|
||||
server.job_manager.create_job(job, actor=actor)
|
||||
await server.job_manager.create_job_async(job, actor=actor)
|
||||
|
||||
# create background tasks
|
||||
asyncio.create_task(load_file_to_source_async(server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor))
|
||||
@@ -182,7 +182,7 @@ async def upload_file_to_source(
|
||||
|
||||
# return job information
|
||||
# Is this necessary? Can we just return the job from create_job?
|
||||
job = server.job_manager.get_job_by_id(job_id=job_id, actor=actor)
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
|
||||
assert job is not None, "Job not found"
|
||||
return job
|
||||
|
||||
|
||||
@@ -1289,9 +1289,9 @@ class SyncServer(Server):
|
||||
async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
|
||||
|
||||
# update job
|
||||
job = self.job_manager.get_job_by_id(job_id, actor=actor)
|
||||
job = await self.job_manager.get_job_by_id_async(job_id, actor=actor)
|
||||
job.status = JobStatus.running
|
||||
self.job_manager.update_job_by_id(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)
|
||||
await self.job_manager.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)
|
||||
|
||||
# try:
|
||||
from letta.data_sources.connectors import DirectoryConnector
|
||||
@@ -1310,18 +1310,18 @@ class SyncServer(Server):
|
||||
|
||||
# Attach source to agent
|
||||
curr_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
|
||||
agent_state = self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor)
|
||||
agent_state = await self.agent_manager.attach_source_async(agent_id=agent_state.id, source_id=source_id, actor=actor)
|
||||
new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
|
||||
assert new_passage_size >= curr_passage_size # in case empty files are added
|
||||
|
||||
# rebuild system prompt and force
|
||||
agent_state = self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True)
|
||||
agent_state = await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
|
||||
|
||||
# update job status
|
||||
job.status = JobStatus.completed
|
||||
job.metadata["num_passages"] = num_passages
|
||||
job.metadata["num_documents"] = num_documents
|
||||
self.job_manager.update_job_by_id(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)
|
||||
await self.job_manager.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)
|
||||
|
||||
return job
|
||||
|
||||
|
||||
@@ -1711,6 +1711,51 @@ class AgentManager:
|
||||
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a source to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to attach the source to
|
||||
source_id: ID of the source to attach
|
||||
actor: User performing the action
|
||||
|
||||
Raises:
|
||||
ValueError: If either agent or source doesn't exist
|
||||
IntegrityError: If the source is already attached to the agent
|
||||
"""
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify both agent and source exist and user has permission to access them
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# The _process_relationship helper already handles duplicate checking via unique constraint
|
||||
await _process_relationship_async(
|
||||
session=session,
|
||||
agent=agent,
|
||||
relationship_name="sources",
|
||||
model_class=SourceModel,
|
||||
item_ids=[source_id],
|
||||
allow_partial=False,
|
||||
replace=False, # Extend existing sources rather than replace
|
||||
)
|
||||
|
||||
# Commit the changes
|
||||
await agent.update_async(session, actor=actor)
|
||||
|
||||
# Force rebuild of system prompt so that the agent is updated with passage count
|
||||
# and recent passages and add system message alert to agent
|
||||
await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
|
||||
await self.append_system_message_async(
|
||||
agent_id=agent_id,
|
||||
content=DATA_SOURCE_ATTACH_ALERT,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def append_system_message(self, agent_id: str, content: str, actor: PydanticUser):
|
||||
@@ -1724,6 +1769,19 @@ class AgentManager:
|
||||
# update agent in-context message IDs
|
||||
self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser):
|
||||
|
||||
# get the agent
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
message = PydanticMessage.dict_to_message(
|
||||
agent_id=agent.id, model=agent.llm_config.model, openai_message_dict={"role": "system", "content": content}
|
||||
)
|
||||
|
||||
# update agent in-context message IDs
|
||||
await self.append_to_in_context_messages_async(messages=[message], agent_id=agent_id, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
|
||||
@@ -1792,6 +1850,34 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a source from an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to detach the source from
|
||||
source_id: ID of the source to detach
|
||||
actor: User performing the action
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify agent exists and user has permission to access it
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Remove the source from the relationship
|
||||
remaining_sources = [s for s in agent.sources if s.id != source_id]
|
||||
|
||||
if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship
|
||||
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
|
||||
|
||||
# Update the sources relationship
|
||||
agent.sources = remaining_sources
|
||||
|
||||
# Commit the changes
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
@@ -137,6 +138,12 @@ class PassageManager:
|
||||
"""Create multiple passages."""
|
||||
return [self.create_passage(p, actor) for p in passages]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
||||
"""Create multiple passages."""
|
||||
return await asyncio.gather(*[self.create_passage_async(p, actor) for p in passages])
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def insert_passage(
|
||||
|
||||
@@ -1098,14 +1098,14 @@ async def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool,
|
||||
async def test_attach_source(server: SyncServer, sarah_agent, default_source, default_user, event_loop):
|
||||
"""Test attaching a source to an agent."""
|
||||
# Attach the source
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
|
||||
# Verify attachment through get_agent_by_id
|
||||
agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user)
|
||||
assert default_source.id in [s.id for s in agent.sources]
|
||||
|
||||
# Verify that attaching the same source again doesn't cause issues
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user)
|
||||
assert len([s for s in agent.sources if s.id == default_source.id]) == 1
|
||||
|
||||
@@ -1118,8 +1118,8 @@ async def test_list_attached_source_ids(server: SyncServer, sarah_agent, default
|
||||
assert len(sources) == 0
|
||||
|
||||
# Attach sources
|
||||
server.agent_manager.attach_source(sarah_agent.id, default_source.id, actor=default_user)
|
||||
server.agent_manager.attach_source(sarah_agent.id, other_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(sarah_agent.id, default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(sarah_agent.id, other_source.id, actor=default_user)
|
||||
|
||||
# List sources and verify
|
||||
sources = await server.agent_manager.list_attached_sources_async(sarah_agent.id, actor=default_user)
|
||||
@@ -1133,39 +1133,42 @@ async def test_list_attached_source_ids(server: SyncServer, sarah_agent, default
|
||||
async def test_detach_source(server: SyncServer, sarah_agent, default_source, default_user, event_loop):
|
||||
"""Test detaching a source from an agent."""
|
||||
# Attach source
|
||||
server.agent_manager.attach_source(sarah_agent.id, default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(sarah_agent.id, default_source.id, actor=default_user)
|
||||
|
||||
# Verify it's attached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user)
|
||||
assert default_source.id in [s.id for s in agent.sources]
|
||||
|
||||
# Detach source
|
||||
server.agent_manager.detach_source(sarah_agent.id, default_source.id, actor=default_user)
|
||||
await server.agent_manager.detach_source_async(sarah_agent.id, default_source.id, actor=default_user)
|
||||
|
||||
# Verify it's detached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user)
|
||||
assert default_source.id not in [s.id for s in agent.sources]
|
||||
|
||||
# Verify that detaching an already detached source doesn't cause issues
|
||||
server.agent_manager.detach_source(sarah_agent.id, default_source.id, actor=default_user)
|
||||
await server.agent_manager.detach_source_async(sarah_agent.id, default_source.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_source_nonexistent_agent(server: SyncServer, default_source, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_source_nonexistent_agent(server: SyncServer, default_source, default_user, event_loop):
|
||||
"""Test attaching a source to a nonexistent agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.attach_source(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_source_nonexistent_source(server: SyncServer, sarah_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_source_nonexistent_source(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""Test attaching a nonexistent source to an agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id="nonexistent-source-id", actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id="nonexistent-source-id", actor=default_user)
|
||||
|
||||
|
||||
def test_detach_source_nonexistent_agent(server: SyncServer, default_source, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_detach_source_nonexistent_agent(server: SyncServer, default_source, default_user, event_loop):
|
||||
"""Test detaching a source from a nonexistent agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.detach_source(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.detach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1183,7 +1186,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age
|
||||
assert len(attached_agents) == 0
|
||||
|
||||
# Attach source to first agent
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
|
||||
# Verify one agent is now attached
|
||||
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
|
||||
@@ -1191,7 +1194,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age
|
||||
assert sarah_agent.id in [a.id for a in attached_agents]
|
||||
|
||||
# Attach source to second agent
|
||||
server.agent_manager.attach_source(agent_id=charles_agent.id, source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=charles_agent.id, source_id=default_source.id, actor=default_user)
|
||||
|
||||
# Verify both agents are now attached
|
||||
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
|
||||
@@ -1201,7 +1204,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age
|
||||
assert charles_agent.id in attached_agent_ids
|
||||
|
||||
# Detach source from first agent
|
||||
server.agent_manager.detach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.detach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
|
||||
# Verify only second agent remains attached
|
||||
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
|
||||
@@ -1928,7 +1931,7 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age
|
||||
"blue shoes",
|
||||
]
|
||||
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
|
||||
for i, text in enumerate(test_passages):
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
@@ -3905,7 +3908,7 @@ async def test_delete_attached_source(server: SyncServer, sarah_agent, default_u
|
||||
)
|
||||
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=source.id, actor=default_user)
|
||||
|
||||
# Delete the source
|
||||
deleted_source = await server.source_manager.delete_source(source_id=source.id, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user