diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 3eebbb24..460a96e2 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -74,7 +74,9 @@ class LettaAgent(BaseAgent): @trace_method async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse: - agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor) + agent_state = await self.agent_manager.get_agent_by_id_async( + agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor + ) _, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps) return _create_letta_response( new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage @@ -82,7 +84,9 @@ class LettaAgent(BaseAgent): @trace_method async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True): - agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor) + agent_state = await self.agent_manager.get_agent_by_id_async( + agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor + ) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( input_messages, agent_state, self.message_manager, self.actor ) @@ -294,7 +298,9 @@ class LettaAgent(BaseAgent): 3. Fetches a response from the LLM 4. Processes the response """ - agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor) + agent_state = await self.agent_manager.get_agent_by_id_async( + agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor + ) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( input_messages, agent_state, self.message_manager, self.actor ) diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 03794cec..e2355ab5 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -145,7 +145,7 @@ class LettaAgentBatch(BaseAgent): agent_mapping = { agent_state.id: agent_state for agent_state in await self.agent_manager.get_agents_by_ids_async( - agent_ids=[request.agent_id for request in batch_requests], actor=self.actor + agent_ids=[request.agent_id for request in batch_requests], include_relationships=["tools", "memory"], actor=self.actor ) } @@ -300,7 +300,9 @@ class LettaAgentBatch(BaseAgent): provider_results = {item.agent_id: item.batch_request_result.result for item in batch_items} # Fetch agent states in a single call - agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids, actor=self.actor) + agent_states = await self.agent_manager.get_agents_by_ids_async( + agent_ids=agent_ids, include_relationships=["tools", "memory"], actor=self.actor + ) agent_state_map = {agent.id: agent for agent in agent_states} # Process each agent's results diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 1401112b..75ffe7b0 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -429,7 +429,7 @@ async def list_blocks( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) try: - agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) + agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor) return agent.memory.blocks except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 3ca6d8ec..915413e5 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -979,18 +979,32 @@ class AgentManager: return agent.to_pydantic() @enforce_types - async def get_agent_by_id_async(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: + async def get_agent_by_id_async( + self, + agent_id: str, + actor: PydanticUser, + include_relationships: Optional[List[str]] = None, + ) -> PydanticAgentState: """Fetch an agent by its ID.""" async with db_registry.async_session() as session: agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) - return await agent.to_pydantic_async() + return await agent.to_pydantic_async(include_relationships=include_relationships) @enforce_types - async def get_agents_by_ids_async(self, agent_ids: list[str], actor: PydanticUser) -> list[PydanticAgentState]: + async def get_agents_by_ids_async( + self, + agent_ids: list[str], + actor: PydanticUser, + include_relationships: Optional[List[str]] = None, + ) -> list[PydanticAgentState]: """Fetch a list of agents by their IDs.""" async with db_registry.async_session() as session: - agents = await AgentModel.read_multiple_async(db_session=session, identifiers=agent_ids, actor=actor) - return [await agent.to_pydantic_async() for agent in agents] + agents = await AgentModel.read_multiple_async( + db_session=session, + identifiers=agent_ids, + actor=actor, + ) + return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) @enforce_types def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState: @@ -1201,7 +1215,7 @@ class AgentManager: @enforce_types async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: - agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor) @enforce_types @@ -1211,7 +1225,7 @@ class AgentManager: @enforce_types async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: - agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor) # TODO: This is duplicated below @@ -1292,7 +1306,7 @@ class AgentManager: Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages """ - agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor) curr_system_message = await self.get_system_message_async( agent_id=agent_id, actor=actor