import json
|
|
import uuid
|
|
from typing import Any, AsyncGenerator, Dict, List, Tuple
|
|
|
|
import openai
|
|
|
|
from letta.agents.base_agent import BaseAgent
|
|
from letta.agents.ephemeral_memory_agent import EphemeralMemoryAgent
|
|
from letta.constants import NON_USER_MSG_PREFIX
|
|
from letta.helpers.datetime_helpers import get_utc_time
|
|
from letta.helpers.tool_execution_helper import (
|
|
add_pre_execution_message,
|
|
enable_strict_mode,
|
|
execute_external_tool,
|
|
remove_request_heartbeat,
|
|
)
|
|
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
|
|
from letta.log import get_logger
|
|
from letta.orm.enums import ToolType
|
|
from letta.schemas.agent import AgentState
|
|
from letta.schemas.block import BlockUpdate
|
|
from letta.schemas.letta_message_content import TextContent
|
|
from letta.schemas.letta_response import LettaResponse
|
|
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
|
from letta.schemas.openai.chat_completion_request import (
|
|
AssistantMessage,
|
|
ChatCompletionRequest,
|
|
Tool,
|
|
ToolCall,
|
|
ToolCallFunction,
|
|
ToolMessage,
|
|
UserMessage,
|
|
)
|
|
from letta.schemas.user import User
|
|
from letta.server.rest_api.utils import (
|
|
convert_letta_messages_to_openai,
|
|
create_assistant_messages_from_openai_response,
|
|
create_input_messages,
|
|
create_letta_messages_from_llm_response,
|
|
)
|
|
from letta.services.agent_manager import AgentManager
|
|
from letta.services.block_manager import BlockManager
|
|
from letta.services.helpers.agent_manager_helper import compile_system_message
|
|
from letta.services.message_manager import MessageManager
|
|
from letta.services.passage_manager import PassageManager
|
|
from letta.services.summarizer.enums import SummarizationMode
|
|
from letta.utils import united_diff
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class VoiceAgent(BaseAgent):
    """
    A function-calling loop for streaming OpenAI responses with tool execution.

    This agent:
    - Streams partial tokens in real-time for low-latency output.
    - Detects tool calls and invokes external tools.
    - Gracefully handles OpenAI API failures (429, etc.) and streams errors.
    """

    def __init__(
        self,
        agent_id: str,
        openai_client: openai.AsyncClient,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        block_manager: BlockManager,
        actor: User,
        message_buffer_limit: int,
        message_buffer_min: int,
        summarization_mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER,
    ):
        """
        Args:
            agent_id: ID of the persisted agent this loop runs on behalf of.
            openai_client: Async OpenAI client used for streaming completions.
            message_manager: Persistence layer for messages.
            agent_manager: Persistence layer for agent state.
            block_manager: Persistence layer for memory blocks.
            actor: The user performing the request (used for authz on all DB calls).
            message_buffer_limit: Max number of in-context messages kept after a step.
            message_buffer_min: Lower bound for the buffer (currently unused; reserved
                for the summarizer).
            summarization_mode: Currently unused — the summarizer wiring is disabled.
        """
        super().__init__(
            agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
        )

        # TODO: Make this more general, factorable
        # Summarizer settings
        self.block_manager = block_manager
        self.passage_manager = PassageManager()  # TODO: pass this in
        # TODO: This is not guaranteed to exist!
        self.summary_block_label = "human"
        # NOTE(review): `summarization_mode` and `message_buffer_min` are accepted but
        # unused — the Summarizer integration is currently disabled.
        self.message_buffer_limit = message_buffer_limit
        self.offline_memory_agent = EphemeralMemoryAgent(
            agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
        )

    async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
        """Synchronous stepping is unsupported for this agent; use `step_stream`."""
        raise NotImplementedError("VoiceAgent does not have a synchronous step implemented currently.")

    async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
        """
        Main streaming loop that yields partial tokens as SSE chunks.

        Each iteration rebuilds the in-context memory, issues a streaming chat
        completion, forwards partial tokens to the caller, and then lets
        `_handle_ai_response` decide whether a tool call requires another loop.
        Terminates after `max_steps` iterations at the latest.
        """
        agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
        in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
        letta_message_db_queue = [create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=self.actor)]
        in_memory_message_history = self.pre_process_input_message(input_messages)

        for _ in range(max_steps):
            # Rebuild memory each loop in case a tool (e.g. recall_memory) changed a block
            in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
            openai_messages = convert_letta_messages_to_openai(in_context_messages)
            openai_messages.extend(in_memory_message_history)

            request = self._build_openai_request(openai_messages, agent_state)

            stream = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
            streaming_interface = OpenAIChatCompletionsStreamingInterface(stream_pre_execution_message=True)

            # 1) Yield partial tokens from OpenAI
            async for sse_chunk in streaming_interface.process(stream):
                yield sse_chunk

            # 2) Now handle the final AI response. This might yield more text (stalling, etc.)
            should_continue = await self._handle_ai_response(
                streaming_interface,
                agent_state,
                in_memory_message_history,
                letta_message_db_queue,
            )

            if not should_continue:
                break

        # Persist this turn and trim the context window if it grew past the buffer limit
        await self._rebuild_context_window(in_context_messages, letta_message_db_queue, agent_state)
        yield "data: [DONE]\n\n"

    async def _handle_ai_response(
        self,
        streaming_interface: "OpenAIChatCompletionsStreamingInterface",
        agent_state: AgentState,
        in_memory_message_history: List[Dict[str, Any]],
        letta_message_db_queue: List[Any],
    ) -> bool:
        """
        Handle the final AI response once streaming is complete.

        Stores any assistant text, executes a requested tool call, and appends the
        tool result plus a heartbeat user message so the model summarizes it.

        Returns:
            True if the caller should run another loop iteration (a tool was
            called), False if the turn is finished (finish_reason == "stop").
        """
        # 1. If we have any leftover content from partial stream, store it as an assistant message
        if streaming_interface.content_buffer:
            content = "".join(streaming_interface.content_buffer)
            in_memory_message_history.append({"role": "assistant", "content": content})

            assistant_msgs = create_assistant_messages_from_openai_response(
                response_text=content,
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                actor=self.actor,
            )
            letta_message_db_queue.extend(assistant_msgs)

        # 2. If a tool call was requested, handle it
        if streaming_interface.tool_call_happened:
            tool_call_name = streaming_interface.tool_call_name
            tool_call_args_str = streaming_interface.tool_call_args_str or "{}"
            try:
                tool_args = json.loads(tool_call_args_str)
            except json.JSONDecodeError:
                # Malformed arguments from the model — fall back to no args
                tool_args = {}

            # The model may omit an ID; synthesize one so tool/assistant messages pair up
            tool_call_id = streaming_interface.tool_call_id or f"call_{uuid.uuid4().hex[:8]}"
            assistant_tool_call_msg = AssistantMessage(
                content=None,
                tool_calls=[
                    ToolCall(
                        id=tool_call_id,
                        function=ToolCallFunction(
                            name=tool_call_name,
                            arguments=tool_call_args_str,
                        ),
                    )
                ],
            )
            in_memory_message_history.append(assistant_tool_call_msg.model_dump())

            tool_result, success_flag = await self._execute_tool(
                tool_name=tool_call_name,
                tool_args=tool_args,
                agent_state=agent_state,
            )

            # 3. Provide function_call response back into the conversation
            tool_message = ToolMessage(
                content=json.dumps({"result": tool_result}),
                tool_call_id=tool_call_id,
            )
            in_memory_message_history.append(tool_message.model_dump())

            # 4. Insert heartbeat message for follow-up
            heartbeat_user_message = UserMessage(
                content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user."
            )
            in_memory_message_history.append(heartbeat_user_message.model_dump())

            # 5. Also store in DB
            tool_call_messages = create_letta_messages_from_llm_response(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                function_name=tool_call_name,
                function_arguments=tool_args,
                tool_call_id=tool_call_id,
                function_call_success=success_flag,
                function_response=tool_result,
                actor=self.actor,
                add_heartbeat_request_system_message=True,
            )
            letta_message_db_queue.extend(tool_call_messages)

            # Because we have new data, we want to continue the while-loop in `step_stream`
            return True
        else:
            # If we got here, there's no tool call. If finish_reason_stop => done
            return not streaming_interface.finish_reason_stop

    async def _rebuild_context_window(
        self, in_context_messages: List[Message], letta_message_db_queue: List[Message], agent_state: AgentState
    ) -> None:
        """
        Persist this turn's queued messages and trim the in-context window.

        Keeps the first message (system prompt) plus the most recent
        `message_buffer_limit - 1` messages, then saves the resulting IDs.
        """
        new_letta_messages = self.message_manager.create_many_messages(letta_message_db_queue, actor=self.actor)
        new_in_context_messages = in_context_messages + new_letta_messages

        if len(new_in_context_messages) > self.message_buffer_limit:
            # Evict the oldest non-system messages; index 0 is the system message
            cutoff = len(new_in_context_messages) - self.message_buffer_limit
            new_in_context_messages = [new_in_context_messages[0]] + new_in_context_messages[cutoff:]

        self.agent_manager.set_in_context_messages(
            agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
        )

    def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
        """
        Refresh the summary memory block and recompile the system message.

        Returns the (possibly updated) in-context message list; the system
        message (index 0) is rewritten only if the compiled memory changed.
        """
        # Refresh memory
        # TODO: This only happens for the summary block
        # TODO: We want to extend this refresh to be general, and stick it in agent_manager
        for i, b in enumerate(agent_state.memory.blocks):
            if b.label == self.summary_block_label:
                agent_state.memory.blocks[i] = self.block_manager.get_block_by_id(block_id=b.id, actor=self.actor)
                break

        # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
        curr_system_message = in_context_messages[0]
        curr_memory_str = agent_state.memory.compile()
        curr_system_message_text = curr_system_message.content[0].text
        if curr_memory_str in curr_system_message_text:
            # NOTE: could this cause issues if a block is removed? (substring match would still work)
            logger.debug(
                f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
            )
            return in_context_messages

        memory_edit_timestamp = get_utc_time()

        num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
        num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

        new_system_message_str = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
            previous_message_count=num_messages,
            archival_memory_size=num_archival_memories,
        )

        diff = united_diff(curr_system_message_text, new_system_message_str)
        if len(diff) > 0:
            logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

            new_system_message = self.message_manager.update_message_by_id(
                curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
            )

            # Skip pulling down the agent's memory again to save on a db call
            return [new_system_message] + in_context_messages[1:]

        else:
            return in_context_messages

    def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest:
        """Assemble the streaming ChatCompletionRequest for this step."""
        # Build the schemas once and reuse them (previously rebuilt a second time for `tools=`)
        tool_schemas = self._build_tool_schemas(agent_state)
        tool_choice = "auto" if tool_schemas else None

        openai_request = ChatCompletionRequest(
            model=agent_state.llm_config.model,
            messages=openai_messages,
            tools=tool_schemas,
            tool_choice=tool_choice,
            user=self.actor.id,
            max_completion_tokens=agent_state.llm_config.max_tokens,
            temperature=agent_state.llm_config.temperature,
            stream=True,
        )
        return openai_request

    def _build_tool_schemas(self, agent_state: AgentState, external_tools_only=True) -> List[Tool]:
        """
        Build the OpenAI tool schemas: the built-in `recall_memory` tool plus the
        agent's tools (by default only external/custom ones), each wrapped with
        strict mode and a pre-execution "utterance" parameter.
        """
        if external_tools_only:
            tools = [t for t in agent_state.tools if t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}]
        else:
            tools = agent_state.tools

        # Special tool state
        recall_memory_utterance_description = (
            "A lengthier message to be uttered while your memories of the current conversation are being re-contextualized. "
            "You should stall naturally and show the user you're thinking hard. The main thing is to not leave the user in silence. "
            "You MUST also include punctuation at the end of this message."
        )
        recall_memory_json = Tool(
            type="function",
            function=enable_strict_mode(
                add_pre_execution_message(
                    {
                        "name": "recall_memory",
                        "description": "Retrieve relevant information from memory based on a given query. Use when you don't remember the answer to a question.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "A description of what the model is trying to recall from memory.",
                                }
                            },
                            "required": ["query"],
                        },
                    },
                    description=recall_memory_utterance_description,
                )
            ),
        )

        # TODO: Customize whether or not to have heartbeats, pre_exec_message, etc.
        return [recall_memory_json] + [
            Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema))))
            for t in tools
        ]

    async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
        """
        Executes a tool and returns (result, success_flag).
        """
        # Special memory case
        if tool_name == "recall_memory":
            # TODO: Make this safe (KeyError if the model omitted "query")
            await self._recall_memory(tool_args["query"], agent_state)
            return f"Successfully recalled memory and populated {self.summary_block_label} block.", True
        else:
            target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
            if not target_tool:
                return f"Tool not found: {tool_name}", False

            try:
                tool_result, _ = execute_external_tool(
                    agent_state=agent_state,
                    function_name=tool_name,
                    function_args=tool_args,
                    target_letta_tool=target_tool,
                    actor=self.actor,
                    allow_agent_state_modifications=False,
                )
                return tool_result, True
            except Exception as e:
                return f"Failed to call tool. Error: {e}", False

    async def _recall_memory(self, query, agent_state: AgentState) -> None:
        """
        Run the offline memory agent on `query` and write its answer into the
        summary block (label == self.summary_block_label).
        """
        results = await self.offline_memory_agent.step([MessageCreate(role="user", content=[TextContent(text=query)])])
        target_block = next(b for b in agent_state.memory.blocks if b.label == self.summary_block_label)
        self.block_manager.update_block(
            block_id=target_block.id, block_update=BlockUpdate(value=results[0].content[0].text), actor=self.actor
        )
|