From 4dff6b648a0811a84a58f754d3a57389393aca65 Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 25 May 2025 23:09:14 -0700 Subject: [PATCH] feat(asyncify): migrate upload and attach source (#2432) --- letta/data_sources/connectors.py | 4 +- letta/server/rest_api/routers/v1/agents.py | 12 +-- letta/server/rest_api/routers/v1/sources.py | 4 +- letta/server/server.py | 10 +-- letta/services/agent_manager.py | 86 +++++++++++++++++++++ letta/services/passage_manager.py | 7 ++ tests/test_managers.py | 39 +++++----- 7 files changed, 129 insertions(+), 33 deletions(-) diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index 41f728c2..5a0329ec 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -98,14 +98,14 @@ async def load_data( embedding_to_document_name[hashable_embedding] = file_name if len(passages) >= 100: # insert passages into passage store - passage_manager.create_many_passages(passages, actor) + await passage_manager.create_many_passages_async(passages, actor) passage_count += len(passages) passages = [] if len(passages) > 0: # insert passages into passage store - passage_manager.create_many_passages(passages, actor) + await passage_manager.create_many_passages_async(passages, actor) passage_count += len(passages) return passage_count, file_count diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 39a25fa7..862b6f63 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -298,7 +298,7 @@ async def detach_tool( @router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent") -def attach_source( +async def attach_source( agent_id: str, source_id: str, background_tasks: BackgroundTasks, @@ -308,8 +308,8 @@ def attach_source( """ Attach a source to an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - agent = server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + agent = await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=source_id, actor=actor) if agent.enable_sleeptime: source = server.source_manager.get_source_by_id(source_id=source_id) background_tasks.add_task(server.sleeptime_document_ingest, agent, source, actor) @@ -317,7 +317,7 @@ def attach_source( @router.patch("/{agent_id}/sources/detach/{source_id}", response_model=AgentState, operation_id="detach_source_from_agent") -def detach_source( +async def detach_source( agent_id: str, source_id: str, server: "SyncServer" = Depends(get_letta_server), @@ -326,8 +326,8 @@ def detach_source( """ Detach a source from an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - agent = server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + agent = await server.agent_manager.detach_source_async(agent_id=agent_id, source_id=source_id, actor=actor) if agent.enable_sleeptime: try: source = server.source_manager.get_source_by_id(source_id=source_id) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 130a0434..d21ed129 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -174,7 +174,7 @@ async def upload_file_to_source( completed_at=None, ) job_id = job.id - server.job_manager.create_job(job, actor=actor) + await server.job_manager.create_job_async(job, actor=actor) # create background tasks asyncio.create_task(load_file_to_source_async(server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor)) @@ -182,7 +182,7 @@ async def upload_file_to_source( # return job information # Is this necessary? Can we just return the job from create_job? - job = server.job_manager.get_job_by_id(job_id=job_id, actor=actor) + job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor) assert job is not None, "Job not found" return job diff --git a/letta/server/server.py b/letta/server/server.py index 8aa3064f..10a56e8a 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1289,9 +1289,9 @@ class SyncServer(Server): async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job: # update job - job = self.job_manager.get_job_by_id(job_id, actor=actor) + job = await self.job_manager.get_job_by_id_async(job_id, actor=actor) job.status = JobStatus.running - self.job_manager.update_job_by_id(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor) + await self.job_manager.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor) # try: from letta.data_sources.connectors import DirectoryConnector @@ -1310,18 +1310,18 @@ class SyncServer(Server): # Attach source to agent curr_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id) - agent_state = self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor) + agent_state = await self.agent_manager.attach_source_async(agent_id=agent_state.id, source_id=source_id, actor=actor) new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id) assert new_passage_size >= curr_passage_size # in case empty files are added # rebuild system prompt and force - agent_state = self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True) + agent_state = await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True) # update job status job.status = JobStatus.completed job.metadata["num_passages"] = num_passages job.metadata["num_documents"] = num_documents - self.job_manager.update_job_by_id(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor) + await self.job_manager.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor) return job diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 31712640..141c97b7 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1711,6 +1711,51 @@ class AgentManager: return agent.to_pydantic() + @trace_method + @enforce_types + async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: + """ + Attaches a source to an agent. + + Args: + agent_id: ID of the agent to attach the source to + source_id: ID of the source to attach + actor: User performing the action + + Raises: + ValueError: If either agent or source doesn't exist + IntegrityError: If the source is already attached to the agent + """ + + async with db_registry.async_session() as session: + # Verify both agent and source exist and user has permission to access them + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # The _process_relationship helper already handles duplicate checking via unique constraint + await _process_relationship_async( + session=session, + agent=agent, + relationship_name="sources", + model_class=SourceModel, + item_ids=[source_id], + allow_partial=False, + replace=False, # Extend existing sources rather than replace + ) + + # Commit the changes + await agent.update_async(session, actor=actor) + + # Force rebuild of system prompt so that the agent is updated with passage count + # and recent passages and add system message alert to agent + await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True) + await self.append_system_message_async( + agent_id=agent_id, + content=DATA_SOURCE_ATTACH_ALERT, + actor=actor, + ) + + return await agent.to_pydantic_async() + @trace_method @enforce_types def append_system_message(self, agent_id: str, content: str, actor: PydanticUser): @@ -1724,6 +1769,19 @@ class AgentManager: # update agent in-context message IDs self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor) + @trace_method + @enforce_types + async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser): + + # get the agent + agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + message = PydanticMessage.dict_to_message( + agent_id=agent.id, model=agent.llm_config.model, openai_message_dict={"role": "system", "content": content} + ) + + # update agent in-context message IDs + await self.append_to_in_context_messages_async(messages=[message], agent_id=agent_id, actor=actor) + @trace_method @enforce_types def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: @@ -1792,6 +1850,34 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + @trace_method + @enforce_types + async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: + """ + Detaches a source from an agent. + + Args: + agent_id: ID of the agent to detach the source from + source_id: ID of the source to detach + actor: User performing the action + """ + async with db_registry.async_session() as session: + # Verify agent exists and user has permission to access it + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # Remove the source from the relationship + remaining_sources = [s for s in agent.sources if s.id != source_id] + + if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship + logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}") + + # Update the sources relationship + agent.sources = remaining_sources + + # Commit the changes + await agent.update_async(session, actor=actor) + return await agent.to_pydantic_async() + # ====================================================================================================================== # Block management # ====================================================================================================================== diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 554265cc..1b801de2 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timezone from functools import lru_cache from typing import List, Optional @@ -137,6 +138,12 @@ class PassageManager: """Create multiple passages.""" return [self.create_passage(p, actor) for p in passages] + @enforce_types + @trace_method + async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: + """Create multiple passages.""" + return await asyncio.gather(*[self.create_passage_async(p, actor) for p in passages]) + @enforce_types @trace_method def insert_passage( diff --git a/tests/test_managers.py b/tests/test_managers.py index 704eb7cb..146872c9 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1098,14 +1098,14 @@ async def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool, async def test_attach_source(server: SyncServer, sarah_agent, default_source, default_user, event_loop): """Test attaching a source to an agent.""" # Attach the source - server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) # Verify attachment through get_agent_by_id agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user) assert default_source.id in [s.id for s in agent.sources] # Verify that attaching the same source again doesn't cause issues - server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user) assert len([s for s in agent.sources if s.id == default_source.id]) == 1 @@ -1118,8 +1118,8 @@ async def test_list_attached_source_ids(server: SyncServer, sarah_agent, default assert len(sources) == 0 # Attach sources - server.agent_manager.attach_source(sarah_agent.id, default_source.id, actor=default_user) - server.agent_manager.attach_source(sarah_agent.id, other_source.id, actor=default_user) + await server.agent_manager.attach_source_async(sarah_agent.id, default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(sarah_agent.id, other_source.id, actor=default_user) # List sources and verify sources = await server.agent_manager.list_attached_sources_async(sarah_agent.id, actor=default_user) @@ -1133,39 +1133,42 @@ async def test_list_attached_source_ids(server: SyncServer, sarah_agent, default async def test_detach_source(server: SyncServer, sarah_agent, default_source, default_user, event_loop): """Test detaching a source from an agent.""" # Attach source - server.agent_manager.attach_source(sarah_agent.id, default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(sarah_agent.id, default_source.id, actor=default_user) # Verify it's attached agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user) assert default_source.id in [s.id for s in agent.sources] # Detach source - server.agent_manager.detach_source(sarah_agent.id, default_source.id, actor=default_user) + await server.agent_manager.detach_source_async(sarah_agent.id, default_source.id, actor=default_user) # Verify it's detached agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user) assert default_source.id not in [s.id for s in agent.sources] # Verify that detaching an already detached source doesn't cause issues - server.agent_manager.detach_source(sarah_agent.id, default_source.id, actor=default_user) + await server.agent_manager.detach_source_async(sarah_agent.id, default_source.id, actor=default_user) -def test_attach_source_nonexistent_agent(server: SyncServer, default_source, default_user): +@pytest.mark.asyncio +async def test_attach_source_nonexistent_agent(server: SyncServer, default_source, default_user, event_loop): """Test attaching a source to a nonexistent agent.""" with pytest.raises(NoResultFound): - server.agent_manager.attach_source(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) -def test_attach_source_nonexistent_source(server: SyncServer, sarah_agent, default_user): +@pytest.mark.asyncio +async def test_attach_source_nonexistent_source(server: SyncServer, sarah_agent, default_user, event_loop): """Test attaching a nonexistent source to an agent.""" with pytest.raises(NoResultFound): - server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id="nonexistent-source-id", actor=default_user) + await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id="nonexistent-source-id", actor=default_user) -def test_detach_source_nonexistent_agent(server: SyncServer, default_source, default_user): +@pytest.mark.asyncio +async def test_detach_source_nonexistent_agent(server: SyncServer, default_source, default_user, event_loop): """Test detaching a source from a nonexistent agent.""" with pytest.raises(NoResultFound): - server.agent_manager.detach_source(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) + await server.agent_manager.detach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) @pytest.mark.asyncio @@ -1183,7 +1186,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age assert len(attached_agents) == 0 # Attach source to first agent - server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) # Verify one agent is now attached attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) @@ -1191,7 +1194,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age assert sarah_agent.id in [a.id for a in attached_agents] # Attach source to second agent - server.agent_manager.attach_source(agent_id=charles_agent.id, source_id=default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(agent_id=charles_agent.id, source_id=default_source.id, actor=default_user) # Verify both agents are now attached attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) @@ -1201,7 +1204,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age assert charles_agent.id in attached_agent_ids # Detach source from first agent - server.agent_manager.detach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + await server.agent_manager.detach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) # Verify only second agent remains attached attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) @@ -1928,7 +1931,7 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age "blue shoes", ] - server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) for i, text in enumerate(test_passages): embedding = embed_model.get_text_embedding(text) @@ -3905,7 +3908,7 @@ async def test_delete_attached_source(server: SyncServer, sarah_agent, default_u ) source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) - server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=source.id, actor=default_user) + await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=source.id, actor=default_user) # Delete the source deleted_source = await server.source_manager.delete_source(source_id=source.id, actor=default_user)