feat: Low Latency Agent (#1157)
This commit is contained in:
0
letta/interfaces/__init__.py
Normal file
0
letta/interfaces/__init__.py
Normal file
109
letta/interfaces/openai_chat_completions_streaming_interface.py
Normal file
109
letta/interfaces/openai_chat_completions_streaming_interface.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
|
||||
|
||||
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.interfaces.utils import _format_sse_chunk
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
|
||||
|
||||
class OpenAIChatCompletionsStreamingInterface:
    """
    Encapsulates the logic for streaming responses from OpenAI.

    This class handles parsing of partial tokens, pre-execution messages,
    and detection of tool call events.
    """

    def __init__(self, stream_pre_execution_message: bool = True):
        # Incrementally parses partial (possibly truncated) JSON from the
        # tool-call arguments as they stream in.
        self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
        # Whether to stream the pre-execution message argument to the client
        # while the tool call is still being generated.
        self.stream_pre_execution_message: bool = stream_pre_execution_message

        # Last parse result of the tool arguments; used to detect whether new
        # pre-execution content has arrived since the previous chunk.
        self.current_parsed_json_result: Dict[str, Any] = {}
        # Accumulated assistant text tokens (joined by the caller afterwards).
        self.content_buffer: List[str] = []
        # Set when the stream ends with finish_reason == "tool_calls".
        self.tool_call_happened: bool = False
        # Set when the stream ends with finish_reason == "stop".
        self.finish_reason_stop: bool = False

        # Tool-call metadata accumulated across chunks.
        self.tool_call_name: Optional[str] = None
        self.tool_call_args_str: str = ""
        self.tool_call_id: Optional[str] = None

    async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[str, None]:
        """
        Iterates over the OpenAI stream, yielding SSE events.

        It also collects tokens and detects if a tool call is triggered.
        """
        async with stream:
            async for chunk in stream:
                # Guard: some chunks carry no choices (e.g. the usage-only
                # final chunk when stream_options.include_usage is set, or
                # provider keep-alives). Skip them instead of raising
                # IndexError on chunk.choices[0].
                if not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta
                finish_reason = choice.finish_reason

                async for sse_chunk in self._process_content(delta, chunk):
                    yield sse_chunk

                async for sse_chunk in self._process_tool_calls(delta, chunk):
                    yield sse_chunk

                if self._handle_finish_reason(finish_reason):
                    break

    async def _process_content(self, delta: ChoiceDelta, chunk: ChatCompletionChunk) -> AsyncGenerator[str, None]:
        """Processes regular content tokens and streams them."""
        if delta.content:
            self.content_buffer.append(delta.content)
            yield _format_sse_chunk(chunk)

    async def _process_tool_calls(self, delta: ChoiceDelta, chunk: ChatCompletionChunk) -> AsyncGenerator[str, None]:
        """Handles tool call initiation and streaming of pre-execution messages."""
        if not delta.tool_calls:
            return

        # Deltas carry at most one tool call fragment in practice.
        tool_call = delta.tool_calls[0]
        self._update_tool_call_info(tool_call)

        if self.stream_pre_execution_message and tool_call.function.arguments:
            self.tool_call_args_str += tool_call.function.arguments
            async for sse_chunk in self._stream_pre_execution_message(chunk, tool_call):
                yield sse_chunk

    def _update_tool_call_info(self, tool_call: Any) -> None:
        """Updates tool call-related attributes (name / id arrive in early fragments)."""
        if tool_call.function.name:
            self.tool_call_name = tool_call.function.name
        if tool_call.id:
            self.tool_call_id = tool_call.id

    async def _stream_pre_execution_message(self, chunk: ChatCompletionChunk, tool_call: Any) -> AsyncGenerator[str, None]:
        """Parses and streams pre-execution messages if they have changed."""
        parsed_args = self.optimistic_json_parser.parse(self.tool_call_args_str)

        if parsed_args.get(PRE_EXECUTION_MESSAGE_ARG) and self.current_parsed_json_result.get(PRE_EXECUTION_MESSAGE_ARG) != parsed_args.get(
            PRE_EXECUTION_MESSAGE_ARG
        ):
            # Only emit when the parse result actually advanced; this avoids
            # re-streaming trailing JSON syntax (e.g. a closing '}').
            if parsed_args != self.current_parsed_json_result:
                self.current_parsed_json_result = parsed_args
                # Wrap the raw argument fragment in a synthetic assistant
                # content chunk so the client renders it as normal text.
                synthetic_chunk = ChatCompletionChunk(
                    id=chunk.id,
                    object=chunk.object,
                    created=chunk.created,
                    model=chunk.model,
                    choices=[
                        Choice(
                            index=0,
                            delta=ChoiceDelta(content=tool_call.function.arguments, role="assistant"),
                            finish_reason=None,
                        )
                    ],
                )
                yield _format_sse_chunk(synthetic_chunk)

    def _handle_finish_reason(self, finish_reason: Optional[str]) -> bool:
        """Handles the finish reason and determines if streaming should stop."""
        if finish_reason == "tool_calls":
            self.tool_call_happened = True
            return True
        if finish_reason == "stop":
            self.finish_reason_stop = True
            return True
        return False
11
letta/interfaces/utils.py
Normal file
11
letta/interfaces/utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import json
|
||||
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
|
||||
def _format_sse_error(error_payload: dict) -> str:
|
||||
return f"data: {json.dumps(error_payload)}\n\n"
|
||||
|
||||
|
||||
def _format_sse_chunk(chunk: ChatCompletionChunk) -> str:
    """Serialize an OpenAI completion chunk as a single Server-Sent-Events data line."""
    payload = chunk.model_dump_json()
    return f"data: {payload}\n\n"
||||
286
letta/low_latency_agent.py
Normal file
286
letta/low_latency_agent.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Tuple
|
||||
|
||||
import openai
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.tool_execution_helper import (
|
||||
add_pre_execution_message,
|
||||
enable_strict_mode,
|
||||
execute_external_tool,
|
||||
remove_request_heartbeat,
|
||||
)
|
||||
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.message import Message, MessageUpdate
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
AssistantMessage,
|
||||
ChatCompletionRequest,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolCallFunction,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.utils import (
|
||||
convert_letta_messages_to_openai,
|
||||
create_assistant_messages_from_openai_response,
|
||||
create_tool_call_messages_from_openai_response,
|
||||
create_user_message,
|
||||
)
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.utils import united_diff
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LowLatencyAgent:
    """
    A function-calling loop for streaming OpenAI responses with tool execution.

    This agent:
    - Streams partial tokens in real-time for low-latency output.
    - Detects tool calls and invokes external tools.
    - Gracefully handles OpenAI API failures (429, etc.) and streams errors.
    """

    def __init__(
        self,
        agent_id: str,
        openai_client: openai.AsyncClient,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        actor: User,
    ):
        self.agent_id = agent_id
        self.openai_client = openai_client

        # DB access related fields
        self.message_manager = message_manager
        self.agent_manager = agent_manager
        self.actor = actor

        # Internal conversation state: strict optimistic parser mirrors the
        # one used by the streaming interface for pre-execution messages.
        self.optimistic_json_parser = OptimisticJSONParser(strict=True)
        self.current_parsed_json_result: Dict[str, Any] = {}

    async def step(self, input_message: Dict[str, str]) -> AsyncGenerator[str, None]:
        """
        Async generator that yields partial tokens as SSE events, handles tool calls,
        and streams error messages if OpenAI API failures occur.
        """
        agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
        # Messages queued for DB persistence at the end of the step loop.
        letta_message_db_queue = [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)]
        # In-memory transcript for this step (OpenAI message dicts).
        in_memory_message_history = [input_message]

        while True:
            # Build context and request
            openai_messages = self._build_context_window(in_memory_message_history, agent_state)
            request = self._build_openai_request(openai_messages, agent_state)

            # Execute the request; a fresh interface per iteration resets the
            # per-response buffers and tool-call state.
            stream = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
            streaming_interface = OpenAIChatCompletionsStreamingInterface(stream_pre_execution_message=True)

            async for sse in streaming_interface.process(stream):
                yield sse

            # Process the AI response (buffered messages, tool execution, etc.)
            continue_execution = await self.handle_ai_response(
                streaming_interface, agent_state, in_memory_message_history, letta_message_db_queue
            )

            if not continue_execution:
                break

        # Persist messages to the database without blocking the event loop.
        await run_in_threadpool(
            self.agent_manager.append_to_in_context_messages,
            letta_message_db_queue,
            agent_id=agent_state.id,
            actor=self.actor,
        )

        yield "data: [DONE]\n\n"

    async def handle_ai_response(
        self,
        streaming_interface: OpenAIChatCompletionsStreamingInterface,
        agent_state: AgentState,
        in_memory_message_history: List[Dict[str, Any]],
        letta_message_db_queue: List[Any],
    ) -> bool:
        """
        Handles AI response processing, including buffering messages, detecting tool calls,
        executing tools, and deciding whether to continue execution.

        Returns:
            bool: True if execution should continue, False if the step loop should terminate.
        """
        # Handle assistant message buffering
        if streaming_interface.content_buffer:
            content = "".join(streaming_interface.content_buffer)
            in_memory_message_history.append({"role": "assistant", "content": content})

            assistant_msgs = create_assistant_messages_from_openai_response(
                response_text=content,
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                actor=self.actor,
            )
            letta_message_db_queue.extend(assistant_msgs)

        # Handle tool execution if a tool call occurred
        if streaming_interface.tool_call_happened:
            # The accumulated arguments string may be malformed JSON if the
            # stream was cut short; fall back to an empty argument dict.
            try:
                tool_args = json.loads(streaming_interface.tool_call_args_str)
            except json.JSONDecodeError:
                tool_args = {}

            # Generate an id if the model did not supply one; the assistant
            # and tool messages below must reference the same id.
            tool_call_id = streaming_interface.tool_call_id or f"call_{uuid.uuid4().hex[:8]}"

            assistant_tool_call_msg = AssistantMessage(
                content=None,
                tool_calls=[
                    ToolCall(
                        id=tool_call_id,
                        function=ToolCallFunction(
                            name=streaming_interface.tool_call_name,
                            arguments=streaming_interface.tool_call_args_str,
                        ),
                    )
                ],
            )
            in_memory_message_history.append(assistant_tool_call_msg.model_dump())

            tool_result, function_call_success = await self._execute_tool(
                tool_name=streaming_interface.tool_call_name,
                tool_args=tool_args,
                agent_state=agent_state,
            )

            tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id)
            in_memory_message_history.append(tool_message.model_dump())

            # Synthetic user message prompting the model to summarize the
            # tool result on the next loop iteration.
            heartbeat_user_message = UserMessage(
                content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user."
            )
            in_memory_message_history.append(heartbeat_user_message.model_dump())

            tool_call_messages = create_tool_call_messages_from_openai_response(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                function_name=streaming_interface.tool_call_name,
                function_arguments=tool_args,
                tool_call_id=tool_call_id,
                function_call_success=function_call_success,
                function_response=tool_result,
                actor=self.actor,
                add_heartbeat_request_system_message=True,
            )
            letta_message_db_queue.extend(tool_call_messages)

            # Continue execution by restarting the loop with updated context
            return True

        # Exit the loop if finish_reason_stop or no tool call occurred
        return not streaming_interface.finish_reason_stop

    def _build_context_window(self, in_memory_message_history: List[Dict[str, Any]], agent_state: AgentState) -> List[Dict]:
        """Builds the OpenAI message list: persisted in-context messages plus this step's history."""
        in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
        in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state)

        # Convert Letta messages to OpenAI messages
        openai_messages = convert_letta_messages_to_openai(in_context_messages)
        openai_messages.extend(in_memory_message_history)
        return openai_messages

    def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
        """Refreshes the system message if the agent's core memory has changed."""
        # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
        curr_system_message = in_context_messages[0]
        curr_memory_str = agent_state.memory.compile()
        if curr_memory_str in curr_system_message.text:
            # NOTE: could this cause issues if a block is removed? (substring match would still work)
            logger.debug(
                f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
            )
            return in_context_messages

        memory_edit_timestamp = get_utc_time()
        new_system_message_str = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
        )

        diff = united_diff(curr_system_message.text, new_system_message_str)
        if len(diff) > 0:
            logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")

            new_system_message = self.message_manager.update_message_by_id(
                curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
            )

            # Skip pulling down the agent's memory again to save on a db call
            return [new_system_message] + in_context_messages[1:]

        else:
            return in_context_messages

    def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest:
        """Assembles the streaming ChatCompletionRequest for the agent's model."""
        tool_schemas = self._build_tool_schemas(agent_state)
        # The OpenAI API rejects an empty `tools` array, so pass None when the
        # agent has no external tools (mirrors the tool_choice handling).
        tools = tool_schemas if tool_schemas else None
        tool_choice = "auto" if tool_schemas else None

        openai_request = ChatCompletionRequest(
            model=agent_state.llm_config.model,
            messages=openai_messages,
            # Reuse the schemas built above instead of rebuilding them.
            tools=tools,
            tool_choice=tool_choice,
            user=self.actor.id,
            max_completion_tokens=agent_state.llm_config.max_tokens,
            temperature=agent_state.llm_config.temperature,
            stream=True,
        )
        return openai_request

    def _build_tool_schemas(self, agent_state: AgentState, external_tools_only=True) -> List[Tool]:
        """Builds OpenAI tool schemas for the agent's tools (external tools only by default)."""
        if external_tools_only:
            tools = [t for t in agent_state.tools if t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}]
        else:
            tools = agent_state.tools

        # TODO: Customize whether or not to have heartbeats, pre_exec_message, etc.
        return [
            Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema))))
            for t in tools
        ]

    async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
        """
        Executes a tool and returns (result, success_flag).
        """
        target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
        if not target_tool:
            return f"Tool not found: {tool_name}", False

        try:
            tool_result, _ = execute_external_tool(
                agent_state=agent_state,
                function_name=tool_name,
                function_args=tool_args,
                target_letta_tool=target_tool,
                actor=self.actor,
                allow_agent_state_modifications=False,
            )
            return tool_result, True
        except Exception as e:
            return f"Failed to call tool. Error: {e}", False
||||
@@ -1,42 +1,14 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from letta.constants import LETTA_TOOL_SET, NON_USER_MSG_PREFIX, PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.helpers.tool_execution_helper import (
|
||||
add_pre_execution_message,
|
||||
enable_strict_mode,
|
||||
execute_external_tool,
|
||||
remove_request_heartbeat,
|
||||
)
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
AssistantMessage,
|
||||
ChatCompletionRequest,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolCallFunction,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.utils import (
|
||||
convert_letta_messages_to_openai,
|
||||
create_assistant_messages_from_openai_response,
|
||||
create_tool_call_messages_from_openai_response,
|
||||
create_user_message,
|
||||
get_letta_server,
|
||||
get_messages_from_completion_request,
|
||||
)
|
||||
from letta.low_latency_agent import LowLatencyAgent
|
||||
from letta.server.rest_api.utils import get_letta_server, get_messages_from_completion_request
|
||||
from letta.settings import model_settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -72,42 +44,14 @@ async def create_voice_chat_completions(
|
||||
if agent_id is None:
|
||||
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
|
||||
|
||||
agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.llm_config.model_endpoint_type != "openai":
|
||||
raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.")
|
||||
# agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
# if agent_state.llm_config.model_endpoint_type != "openai":
|
||||
# raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.")
|
||||
|
||||
# Convert Letta messages to OpenAI messages
|
||||
in_context_messages = server.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)
|
||||
openai_messages = convert_letta_messages_to_openai(in_context_messages)
|
||||
|
||||
# Also parse user input from completion_request and append
|
||||
# Also parse the user's new input
|
||||
input_message = get_messages_from_completion_request(completion_request)[-1]
|
||||
openai_messages.append(input_message)
|
||||
|
||||
# Tools we allow this agent to call
|
||||
tools = [t for t in agent_state.tools if t.name not in LETTA_TOOL_SET and t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}]
|
||||
|
||||
# Initial request
|
||||
openai_request = ChatCompletionRequest(
|
||||
model=agent_state.llm_config.model,
|
||||
messages=openai_messages,
|
||||
# TODO: This nested thing here is so ugly, need to refactor
|
||||
tools=(
|
||||
[
|
||||
Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema))))
|
||||
for t in tools
|
||||
]
|
||||
if tools
|
||||
else None
|
||||
),
|
||||
tool_choice="auto",
|
||||
user=user_id,
|
||||
max_completion_tokens=agent_state.llm_config.max_tokens,
|
||||
temperature=agent_state.llm_config.temperature,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Create the OpenAI async client
|
||||
# Create OpenAI async client
|
||||
client = openai.AsyncClient(
|
||||
api_key=model_settings.openai_api_key,
|
||||
max_retries=0,
|
||||
@@ -122,194 +66,14 @@ async def create_voice_chat_completions(
|
||||
),
|
||||
)
|
||||
|
||||
# The messages we want to persist to the Letta agent
|
||||
user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
|
||||
message_db_queue = [user_message]
|
||||
# Instantiate our LowLatencyAgent
|
||||
agent = LowLatencyAgent(
|
||||
agent_id=agent_id,
|
||||
openai_client=client,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
async def event_stream():
|
||||
"""
|
||||
A function-calling loop:
|
||||
- We stream partial tokens.
|
||||
- If we detect a tool call (finish_reason="tool_calls"), we parse it,
|
||||
add two messages to the conversation:
|
||||
(a) assistant message with tool_calls referencing the same ID
|
||||
(b) a tool message referencing that ID, containing the tool result.
|
||||
- Re-invoke the OpenAI request with updated conversation, streaming again.
|
||||
- End when finish_reason="stop" or no more tool calls.
|
||||
"""
|
||||
|
||||
# We'll keep updating this conversation in a loop
|
||||
conversation = openai_messages[:]
|
||||
|
||||
while True:
|
||||
# Make the streaming request to OpenAI
|
||||
stream = await client.chat.completions.create(**openai_request.model_dump(exclude_unset=True))
|
||||
|
||||
content_buffer = []
|
||||
tool_call_name = None
|
||||
tool_call_args_str = ""
|
||||
tool_call_id = None
|
||||
tool_call_happened = False
|
||||
finish_reason_stop = False
|
||||
optimistic_json_parser = OptimisticJSONParser(strict=True)
|
||||
current_parsed_json_result = {}
|
||||
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
finish_reason = choice.finish_reason # "tool_calls", "stop", or None
|
||||
|
||||
if delta.content:
|
||||
content_buffer.append(delta.content)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# CASE B: Partial tool call info
|
||||
if delta.tool_calls:
|
||||
# Typically there's only one in delta.tool_calls
|
||||
tc = delta.tool_calls[0]
|
||||
if tc.function.name:
|
||||
tool_call_name = tc.function.name
|
||||
if tc.function.arguments:
|
||||
tool_call_args_str += tc.function.arguments
|
||||
|
||||
# See if we can stream out the pre-execution message
|
||||
parsed_args = optimistic_json_parser.parse(tool_call_args_str)
|
||||
if parsed_args.get(
|
||||
PRE_EXECUTION_MESSAGE_ARG
|
||||
) and current_parsed_json_result.get( # Ensure key exists and is not None/empty
|
||||
PRE_EXECUTION_MESSAGE_ARG
|
||||
) != parsed_args.get(
|
||||
PRE_EXECUTION_MESSAGE_ARG
|
||||
):
|
||||
# Only stream if there's something new to stream
|
||||
# We do this way to avoid hanging JSON at the end of the stream, e.g. '}'
|
||||
if parsed_args != current_parsed_json_result:
|
||||
current_parsed_json_result = parsed_args
|
||||
synthetic_chunk = ChatCompletionChunk(
|
||||
id=chunk.id,
|
||||
object=chunk.object,
|
||||
created=chunk.created,
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
Choice(
|
||||
index=choice.index,
|
||||
delta=ChoiceDelta(content=tc.function.arguments, role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
yield f"data: {synthetic_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# We might generate a unique ID for the tool call
|
||||
if tc.id:
|
||||
tool_call_id = tc.id
|
||||
|
||||
# Check finish_reason
|
||||
if finish_reason == "tool_calls":
|
||||
tool_call_happened = True
|
||||
break
|
||||
elif finish_reason == "stop":
|
||||
finish_reason_stop = True
|
||||
break
|
||||
|
||||
if content_buffer:
|
||||
# We treat that partial text as an assistant message
|
||||
content = "".join(content_buffer)
|
||||
conversation.append({"role": "assistant", "content": content})
|
||||
|
||||
# Create an assistant message here to persist later
|
||||
assistant_messages = create_assistant_messages_from_openai_response(
|
||||
response_text=content, agent_id=agent_id, model=agent_state.llm_config.model, actor=actor
|
||||
)
|
||||
message_db_queue.extend(assistant_messages)
|
||||
|
||||
if tool_call_happened:
|
||||
# Parse the tool call arguments
|
||||
try:
|
||||
tool_args = json.loads(tool_call_args_str)
|
||||
except json.JSONDecodeError:
|
||||
tool_args = {}
|
||||
|
||||
if not tool_call_id:
|
||||
# If no tool_call_id given by the model, generate one
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 1) Insert the "assistant" message with the tool_calls field
|
||||
# referencing the same tool_call_id
|
||||
assistant_tool_call_msg = AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[ToolCall(id=tool_call_id, function=ToolCallFunction(name=tool_call_name, arguments=tool_call_args_str))],
|
||||
)
|
||||
|
||||
conversation.append(assistant_tool_call_msg.model_dump())
|
||||
|
||||
# 2) Execute the tool
|
||||
target_tool = next((x for x in tools if x.name == tool_call_name), None)
|
||||
if not target_tool:
|
||||
# Tool not found, handle error
|
||||
yield f"data: {json.dumps({'error': 'Tool not found', 'tool': tool_call_name})}\n\n"
|
||||
break
|
||||
|
||||
try:
|
||||
tool_result, _ = execute_external_tool(
|
||||
agent_state=agent_state,
|
||||
function_name=tool_call_name,
|
||||
function_args=tool_args,
|
||||
target_letta_tool=target_tool,
|
||||
actor=actor,
|
||||
allow_agent_state_modifications=False,
|
||||
)
|
||||
function_call_success = True
|
||||
except Exception as e:
|
||||
tool_result = f"Failed to call tool. Error: {e}"
|
||||
function_call_success = False
|
||||
|
||||
# 3) Insert the "tool" message referencing the same tool_call_id
|
||||
tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id)
|
||||
|
||||
conversation.append(tool_message.model_dump())
|
||||
|
||||
# 4) Add a user message prompting the tool call result summarization
|
||||
heartbeat_user_message = UserMessage(
|
||||
content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user.",
|
||||
)
|
||||
conversation.append(heartbeat_user_message.model_dump())
|
||||
|
||||
# Now, re-invoke OpenAI with the updated conversation
|
||||
openai_request.messages = conversation
|
||||
|
||||
# Create a tool call message and append to message_db_queue
|
||||
tool_call_messages = create_tool_call_messages_from_openai_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call_name,
|
||||
function_arguments=tool_args,
|
||||
tool_call_id=tool_call_id,
|
||||
function_call_success=function_call_success,
|
||||
function_response=tool_result,
|
||||
actor=actor,
|
||||
add_heartbeat_request_system_message=True,
|
||||
)
|
||||
message_db_queue.extend(tool_call_messages)
|
||||
|
||||
continue # Start the while loop again
|
||||
|
||||
if finish_reason_stop:
|
||||
break
|
||||
|
||||
# If we reach here, no tool call, no "stop", but we've ended streaming
|
||||
# Possibly a model error or some other finish reason. We'll just end.
|
||||
break
|
||||
|
||||
await run_in_threadpool(
|
||||
server.agent_manager.append_to_in_context_messages,
|
||||
message_db_queue,
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
# Return the streaming generator
|
||||
return StreamingResponse(agent.step(input_message=input_message), media_type="text/event-stream")
|
||||
|
||||
@@ -13,7 +13,7 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REQ_HEARTBEAT_MESSAGE
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
|
||||
from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
@@ -216,9 +216,10 @@ def create_tool_call_messages_from_openai_response(
|
||||
messages.append(tool_message)
|
||||
|
||||
if add_heartbeat_request_system_message:
|
||||
text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE
|
||||
heartbeat_system_message = Message(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=get_heartbeat(REQ_HEARTBEAT_MESSAGE))],
|
||||
content=[TextContent(text=get_heartbeat(text_content))],
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
|
||||
@@ -529,6 +529,7 @@ class AgentManager:
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict={"role": "system", "content": new_system_message_str},
|
||||
)
|
||||
# TODO: This seems kind of silly, why not just update the message?
|
||||
message = self.message_manager.create_message(message, actor=actor)
|
||||
message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system)
|
||||
return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@@ -153,7 +153,7 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
|
||||
@pytest.mark.parametrize("message", ["What's the weather in SF?"])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice"])
|
||||
async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
@@ -163,8 +163,7 @@ async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
|
||||
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
|
||||
print(chunk)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user