feat: only include relevant relationships on agent fetch (#2319)

This commit is contained in:
cthomas
2025-05-21 15:58:55 -07:00
committed by GitHub
parent 353bd607df
commit f10a15a8ac
4 changed files with 36 additions and 14 deletions

View File

@@ -74,7 +74,9 @@ class LettaAgent(BaseAgent):
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
)
_, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps)
return _create_letta_response(
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
@@ -82,7 +84,9 @@ class LettaAgent(BaseAgent):
@trace_method
async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True):
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
)
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)
@@ -294,7 +298,9 @@ class LettaAgent(BaseAgent):
3. Fetches a response from the LLM
4. Processes the response
"""
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
)
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)

View File

@@ -145,7 +145,7 @@ class LettaAgentBatch(BaseAgent):
agent_mapping = {
agent_state.id: agent_state
for agent_state in await self.agent_manager.get_agents_by_ids_async(
agent_ids=[request.agent_id for request in batch_requests], actor=self.actor
agent_ids=[request.agent_id for request in batch_requests], include_relationships=["tools", "memory"], actor=self.actor
)
}
@@ -300,7 +300,9 @@ class LettaAgentBatch(BaseAgent):
provider_results = {item.agent_id: item.batch_request_result.result for item in batch_items}
# Fetch agent states in a single call
agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids, actor=self.actor)
agent_states = await self.agent_manager.get_agents_by_ids_async(
agent_ids=agent_ids, include_relationships=["tools", "memory"], actor=self.actor
)
agent_state_map = {agent.id: agent for agent in agent_states}
# Process each agent's results

View File

@@ -429,7 +429,7 @@ async def list_blocks(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
return agent.memory.blocks
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -979,18 +979,32 @@ class AgentManager:
return agent.to_pydantic()
@enforce_types
async def get_agent_by_id_async(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
async def get_agent_by_id_async(
self,
agent_id: str,
actor: PydanticUser,
include_relationships: Optional[List[str]] = None,
) -> PydanticAgentState:
"""Fetch an agent by its ID."""
async with db_registry.async_session() as session:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
return await agent.to_pydantic_async()
return await agent.to_pydantic_async(include_relationships=include_relationships)
@enforce_types
async def get_agents_by_ids_async(self, agent_ids: list[str], actor: PydanticUser) -> list[PydanticAgentState]:
async def get_agents_by_ids_async(
self,
agent_ids: list[str],
actor: PydanticUser,
include_relationships: Optional[List[str]] = None,
) -> list[PydanticAgentState]:
"""Fetch a list of agents by their IDs."""
async with db_registry.async_session() as session:
agents = await AgentModel.read_multiple_async(db_session=session, identifiers=agent_ids, actor=actor)
return [await agent.to_pydantic_async() for agent in agents]
agents = await AgentModel.read_multiple_async(
db_session=session,
identifiers=agent_ids,
actor=actor,
)
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
@enforce_types
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
@@ -1201,7 +1215,7 @@ class AgentManager:
@enforce_types
async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
@enforce_types
@@ -1211,7 +1225,7 @@ class AgentManager:
@enforce_types
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
# TODO: This is duplicated below
@@ -1292,7 +1306,7 @@ class AgentManager:
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
"""
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
curr_system_message = await self.get_system_message_async(
agent_id=agent_id, actor=actor