letta-server/letta/agents/low_latency_agent.py

import json
import uuid
from typing import Any, AsyncGenerator, Dict, List, Tuple

import openai

from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_agent import EphemeralAgent
from letta.constants import NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import (
    add_pre_execution_message,
    enable_strict_mode,
    execute_external_tool,
    remove_request_heartbeat,
)
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.block import BlockUpdate
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.openai.chat_completion_request import (
    AssistantMessage,
    ChatCompletionRequest,
    Tool,
    ToolCall,
    ToolCallFunction,
    ToolMessage,
    UserMessage,
)
from letta.schemas.user import User
from letta.server.rest_api.utils import (
    convert_letta_messages_to_openai,
    create_assistant_messages_from_openai_response,
    create_tool_call_messages_from_openai_response,
    create_user_message,
)
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.utils import united_diff

logger = get_logger(__name__)


class LowLatencyAgent(BaseAgent):
    """
    A function-calling loop for streaming OpenAI responses with tool execution.

    This agent:
    - Streams partial tokens in real time for low-latency output.
    - Detects tool calls and invokes external tools.
    - Gracefully handles OpenAI API failures (429, etc.) and streams errors.
    """

    def __init__(
        self,
        agent_id: str,
        openai_client: openai.AsyncClient,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        block_manager: BlockManager,
        actor: User,
        summarization_mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER,
        message_buffer_limit: int = 10,
        message_buffer_min: int = 4,
    ):
        super().__init__(
            agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
        )

        # TODO: Make this more general, factorable
        # Summarizer settings
        self.block_manager = block_manager
        self.passage_manager = PassageManager()  # TODO: pass this in
        # TODO: This is not guaranteed to exist!
        self.summary_block_label = "human"
        self.summarizer = Summarizer(
            mode=summarization_mode,
            summarizer_agent=EphemeralAgent(
                agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
            ),
            message_buffer_limit=message_buffer_limit,
            message_buffer_min=message_buffer_min,
        )
        self.message_buffer_limit = message_buffer_limit
        self.message_buffer_min = message_buffer_min

    async def step(self, input_message: UserMessage) -> List[Message]:
        raise NotImplementedError("LowLatencyAgent does not currently implement a non-streaming step; use step_stream instead.")

    async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
        """
        Async generator that yields partial tokens as SSE events, handles tool calls,
        and streams error messages if OpenAI API failures occur.
        """
        input_message = self.pre_process_input_message(input_message=input_message)
        agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
        in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
        letta_message_db_queue = [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)]
        in_memory_message_history = [input_message]

        while True:
            # Constantly pull down and integrate memory blocks
            in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state)

            # Convert Letta messages to OpenAI messages
            openai_messages = convert_letta_messages_to_openai(in_context_messages)
            openai_messages.extend(in_memory_message_history)
            request = self._build_openai_request(openai_messages, agent_state)

            # Execute the request
            stream = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
            streaming_interface = OpenAIChatCompletionsStreamingInterface(stream_pre_execution_message=True)
            async for sse in streaming_interface.process(stream):
                yield sse

            # Process the AI response (buffered messages, tool execution, etc.)
            continue_execution = await self._handle_ai_response(
                streaming_interface, agent_state, in_memory_message_history, letta_message_db_queue
            )
            if not continue_execution:
                break

        # Rebuild context window
        await self._rebuild_context_window(in_context_messages, letta_message_db_queue, agent_state)
        yield "data: [DONE]\n\n"

    async def _handle_ai_response(
        self,
        streaming_interface: OpenAIChatCompletionsStreamingInterface,
        agent_state: AgentState,
        in_memory_message_history: List[Dict[str, Any]],
        letta_message_db_queue: List[Any],
    ) -> bool:
        """
        Handles AI response processing, including buffering messages, detecting tool calls,
        executing tools, and deciding whether to continue execution.

        Returns:
            bool: True if execution should continue, False if the step loop should terminate.
        """
        # Handle assistant message buffering
        if streaming_interface.content_buffer:
            content = "".join(streaming_interface.content_buffer)
            in_memory_message_history.append({"role": "assistant", "content": content})

            assistant_msgs = create_assistant_messages_from_openai_response(
                response_text=content,
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                actor=self.actor,
            )
            letta_message_db_queue.extend(assistant_msgs)

        # Handle tool execution if a tool call occurred
        if streaming_interface.tool_call_happened:
            try:
                tool_args = json.loads(streaming_interface.tool_call_args_str)
            except json.JSONDecodeError:
                tool_args = {}

            tool_call_id = streaming_interface.tool_call_id or f"call_{uuid.uuid4().hex[:8]}"
            assistant_tool_call_msg = AssistantMessage(
                content=None,
                tool_calls=[
                    ToolCall(
                        id=tool_call_id,
                        function=ToolCallFunction(
                            name=streaming_interface.tool_call_name,
                            arguments=streaming_interface.tool_call_args_str,
                        ),
                    )
                ],
            )
            in_memory_message_history.append(assistant_tool_call_msg.model_dump())

            tool_result, function_call_success = await self._execute_tool(
                tool_name=streaming_interface.tool_call_name,
                tool_args=tool_args,
                agent_state=agent_state,
            )
            tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id)
            in_memory_message_history.append(tool_message.model_dump())

            heartbeat_user_message = UserMessage(
                content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user."
            )
            in_memory_message_history.append(heartbeat_user_message.model_dump())

            tool_call_messages = create_tool_call_messages_from_openai_response(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                function_name=streaming_interface.tool_call_name,
                function_arguments=tool_args,
                tool_call_id=tool_call_id,
                function_call_success=function_call_success,
                function_response=tool_result,
                actor=self.actor,
                add_heartbeat_request_system_message=True,
            )
            letta_message_db_queue.extend(tool_call_messages)

            # Continue execution by restarting the loop with updated context
            return True
        # No tool call occurred: stop the loop only if the stream finished with "stop"
        return not streaming_interface.finish_reason_stop

    async def _rebuild_context_window(
        self, in_context_messages: List[Message], letta_message_db_queue: List[Message], agent_state: AgentState
    ) -> None:
        new_letta_messages = self.message_manager.create_many_messages(letta_message_db_queue, actor=self.actor)

        # TODO: Make this more general and configurable, less brittle
        target_block = next(b for b in agent_state.memory.blocks if b.label == self.summary_block_label)
        previous_summary = self.block_manager.get_block_by_id(block_id=target_block.id, actor=self.actor).value
        new_in_context_messages, summary_str, updated = await self.summarizer.summarize(
            in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, previous_summary=previous_summary
        )

        if updated:
            self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=summary_str), actor=self.actor)

        self.agent_manager.set_in_context_messages(
            agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
        )

    def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
        # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
        curr_system_message = in_context_messages[0]
        curr_memory_str = agent_state.memory.compile()
        curr_system_message_text = curr_system_message.content[0].text
        if curr_memory_str in curr_system_message_text:
            # NOTE: could this cause issues if a block is removed? (substring match would still work)
            logger.debug(
                f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
            )
            return in_context_messages

        memory_edit_timestamp = get_utc_time()
        num_messages = self.message_manager.size(actor=self.actor, agent_id=self.agent_id)
        num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=self.agent_id)

        new_system_message_str = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
            previous_message_count=num_messages,
            archival_memory_size=num_archival_memories,
        )

        diff = united_diff(curr_system_message_text, new_system_message_str)
        if len(diff) > 0:
            logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")
            new_system_message = self.message_manager.update_message_by_id(
                curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
            )
            # Skip pulling down the agent's memory again to save on a db call
            return [new_system_message] + in_context_messages[1:]
        else:
            return in_context_messages

    def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest:
        tool_schemas = self._build_tool_schemas(agent_state)
        tool_choice = "auto" if tool_schemas else None

        openai_request = ChatCompletionRequest(
            model=agent_state.llm_config.model,
            messages=openai_messages,
            tools=tool_schemas,
            tool_choice=tool_choice,
            user=self.actor.id,
            max_completion_tokens=agent_state.llm_config.max_tokens,
            temperature=agent_state.llm_config.temperature,
            stream=True,
        )
        return openai_request

    def _build_tool_schemas(self, agent_state: AgentState, external_tools_only=True) -> List[Tool]:
        if external_tools_only:
            tools = [t for t in agent_state.tools if t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}]
        else:
            tools = agent_state.tools

        # TODO: Customize whether or not to have heartbeats, pre_exec_message, etc.
        return [
            Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema))))
            for t in tools
        ]

    async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
        """
        Executes a tool and returns (result, success_flag).
        """
        target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
        if not target_tool:
            return f"Tool not found: {tool_name}", False

        try:
            tool_result, _ = execute_external_tool(
                agent_state=agent_state,
                function_name=tool_name,
                function_args=tool_args,
                target_letta_tool=target_tool,
                actor=self.actor,
                allow_agent_state_modifications=False,
            )
            return tool_result, True
        except Exception as e:
            return f"Failed to call tool. Error: {e}", False