chore: migrate package name to letta (#1775)
Co-authored-by: Charles Packer <packercharles@gmail.com> Co-authored-by: Shubham Naik <shubham.naik10@gmail.com> Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
396
letta/streaming_interface.py
Normal file
396
letta/streaming_interface.py
Normal file
@@ -0,0 +1,396 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
# from colorama import Fore, Style, init
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markup import escape
|
||||
|
||||
from letta.interface import CLIInterface
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
ChatCompletionChunkResponse,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
|
||||
# init(autoreset=True)

# DEBUG = True  # puts full message outputs in the terminal
DEBUG = False  # only dumps important messages in the terminal

# NOTE(review): not referenced anywhere in this file — presumably consumed by
# importers as a module-level toggle; confirm before removing.
STRIP_UI = False
|
||||
|
||||
|
||||
class AgentChunkStreamingInterface(ABC):
    """Interfaces handle Letta-related events (observer pattern)

    The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata.

    Implementations receive per-delta chunks via process_chunk, bracketed by
    stream_start/stream_end calls.
    """

    @abstractmethod
    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta receives a user message"""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta generates some internal monologue"""
        raise NotImplementedError

    @abstractmethod
    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta uses send_message"""
        raise NotImplementedError

    @abstractmethod
    def function_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta calls a function"""
        raise NotImplementedError

    @abstractmethod
    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime):
        """Process a streaming chunk from an OpenAI-compatible server"""
        raise NotImplementedError

    @abstractmethod
    def stream_start(self):
        """Any setup required before streaming begins"""
        raise NotImplementedError

    @abstractmethod
    def stream_end(self):
        """Any cleanup required after streaming ends"""
        raise NotImplementedError
|
||||
|
||||
|
||||
class StreamingCLIInterface(AgentChunkStreamingInterface):
    """Version of the CLI interface that attaches to a stream generator and prints along the way.

    When a chunk is received, we write the delta to the buffer. If the buffer type has changed,
    we write out a newline + set the formatting for the new line.

    The two buffer types are:
      (1) content (inner thoughts)
      (2) tool_calls (function calling)

    NOTE: this assumes that the deltas received in the chunks are in-order, e.g.
    that once 'content' deltas stop streaming, they won't be received again. See notes
    on alternative version of the StreamingCLIInterface that does not have this same problem below:

    An alternative implementation could instead maintain the partial message state, and on each
    process chunk (1) update the partial message state, (2) refresh/rewrite the state to the screen.
    """

    # CLIInterface is static/stateless, so one shared instance is enough
    nonstreaming_interface = CLIInterface()

    def __init__(self):
        """The streaming CLI interface state for determining which buffer is currently being written to"""
        # None until the first delta arrives, then "content" or "tool_calls"
        self.streaming_buffer_type = None

    def _flush(self):
        # No buffered state to flush — all prints below use flush=True
        pass

    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime):
        """Write a single streaming delta to stdout, printing a section prefix when a new buffer starts."""
        assert len(chunk.choices) == 1, chunk

        message_delta = chunk.choices[0].delta

        # Starting a new buffer line
        if not self.streaming_buffer_type:
            # A single delta should never carry both content and tool_calls
            assert not (
                message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)
            ), f"Error: got both content and tool_calls in message stream\n{message_delta}"

            if message_delta.content is not None:
                # Write out the prefix for inner thoughts
                print("Inner thoughts: ", end="", flush=True)
            elif message_delta.tool_calls is not None:
                assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
                # Write out the prefix for function calling
                print("Calling function: ", end="", flush=True)

        # Write out the delta
        if message_delta.content is not None:
            if self.streaming_buffer_type and self.streaming_buffer_type != "content":
                # Buffer type switched — terminate the previous line first
                print()
            self.streaming_buffer_type = "content"

            # Simple, just write out to the buffer
            print(message_delta.content, end="", flush=True)

        elif message_delta.tool_calls is not None:
            if self.streaming_buffer_type and self.streaming_buffer_type != "tool_calls":
                print()
            self.streaming_buffer_type = "tool_calls"

            assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
            function_call = message_delta.tool_calls[0].function

            # Slightly more complex - want to write parameters in a certain way (paren-style)
            # function_name(function_args
            if function_call and function_call.name:
                # NOTE: need to account for closing the brace later (done in stream_end)
                print(f"{function_call.name}(", end="", flush=True)
            if function_call and function_call.arguments:
                print(function_call.arguments, end="", flush=True)

    def stream_start(self):
        # should be handled by stream_end(), but just in case
        self.streaming_buffer_type = None

    def stream_end(self):
        """Close any open tool-call paren and reset the buffer tracker."""
        if self.streaming_buffer_type is not None:
            # TODO: should have a separate self.tool_call_open_paren flag
            if self.streaming_buffer_type == "tool_calls":
                print(")", end="", flush=True)

            print()  # newline to move the cursor
            self.streaming_buffer_type = None  # reset buffer tracker

    # The remaining methods delegate to the shared (non-streaming) CLIInterface.
    # BUGFIX: these previously *called the CLIInterface instance itself*, e.g.
    # `StreamingCLIInterface.nonstreaming_interface(msg)`, instead of the matching
    # named method — compare print_messages_simple/print_messages_raw below, and the
    # named-method delegation in StreamingRefreshCLIInterface.

    @staticmethod
    def important_message(msg: str):
        StreamingCLIInterface.nonstreaming_interface.important_message(msg)

    @staticmethod
    def warning_message(msg: str):
        StreamingCLIInterface.nonstreaming_interface.warning_message(msg)

    @staticmethod
    def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)

    @staticmethod
    def assistant_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)

    @staticmethod
    def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
        # raw/dump/debug kept for interface compatibility; not forwarded (as before)
        StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
        StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)

    @staticmethod
    def step_yield():
        pass
|
||||
|
||||
|
||||
class AgentRefreshStreamingInterface(ABC):
    """Same as the ChunkStreamingInterface, but implementations receive a full
    (refreshed) ChatCompletionResponse via process_refresh instead of per-delta chunks.

    The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata.
    """

    @abstractmethod
    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta receives a user message"""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta generates some internal monologue"""
        raise NotImplementedError

    @abstractmethod
    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta uses send_message"""
        raise NotImplementedError

    @abstractmethod
    def function_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta calls a function"""
        raise NotImplementedError

    @abstractmethod
    def process_refresh(self, response: ChatCompletionResponse):
        """Process a streaming chunk from an OpenAI-compatible server"""
        raise NotImplementedError

    @abstractmethod
    def stream_start(self):
        """Any setup required before streaming begins"""
        raise NotImplementedError

    @abstractmethod
    def stream_end(self):
        """Any cleanup required after streaming ends"""
        raise NotImplementedError

    @abstractmethod
    def toggle_streaming(self, on: bool):
        """Toggle streaming on/off (off = regular CLI interface)"""
        raise NotImplementedError
|
||||
|
||||
|
||||
class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
    """Version of the CLI interface that attaches to a stream generator and refreshes a render of the message at every step.

    We maintain the partial message state in the interface state, and on each
    process chunk we:
        (1) update the partial message state,
        (2) refresh/rewrite the state to the screen.
    """

    # NOTE(review): this is the CLIInterface *class* (not an instance), yet the
    # delegating methods below go through StreamingCLIInterface.nonstreaming_interface
    # instead — presumably CLIInterface's methods are static; confirm before consolidating.
    nonstreaming_interface = CLIInterface

    def __init__(self, fancy: bool = True, separate_send_message: bool = True, disable_inner_mono_call: bool = True):
        """Initialize the streaming CLI interface state.

        Args:
            fancy: use rich markup / emoji in the rendered output
            separate_send_message: render send_message tool calls specially (and suppress assistant_message pass-through)
            disable_inner_mono_call: suppress internal_monologue pass-through while streaming
        """
        self.console = Console()

        # Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates
        self.live = Live("", console=self.console, refresh_per_second=10)
        # self.live.start()  # Start the Live display context and keep it running

        # Use italics / emoji?
        self.fancy = fancy

        self.streaming = True
        self.separate_send_message = separate_send_message
        self.disable_inner_mono_call = disable_inner_mono_call

    def toggle_streaming(self, on: bool):
        """Toggle streaming on/off; off falls back to regular (non-streaming) CLI rendering."""
        self.streaming = on
        if on:
            self.separate_send_message = True
            self.disable_inner_mono_call = True
        else:
            self.separate_send_message = False
            self.disable_inner_mono_call = False

    def update_output(self, content: str):
        """Update the displayed output with new content."""
        # We use the `Live` object's update mechanism to refresh content without clearing the console
        if not self.fancy:
            # escape() keeps [brackets] in raw content from being parsed as rich markup
            content = escape(content)
        self.live.update(self.console.render_str(content), refresh=True)

    def process_refresh(self, response: ChatCompletionResponse):
        """Process the response to rewrite the current output buffer."""
        if not response.choices:
            self.update_output("💭 [italic]...[/italic]")
            return  # Early exit if there are no choices

        choice = response.choices[0]
        inner_thoughts = choice.message.content if choice.message.content else ""
        tool_calls = choice.message.tool_calls if choice.message.tool_calls else []

        if self.fancy:
            message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else ""
        else:
            message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else ""

        if tool_calls:
            function_call = tool_calls[0].function
            function_name = function_call.name  # Function name, can be an empty string
            function_args = function_call.arguments  # Function arguments, can be an empty string
            if message_string:
                message_string += "\n"
            # special case here for send_message
            if self.separate_send_message and function_name == "send_message":
                try:
                    message = json.loads(function_args)["message"]
                # BUGFIX: was a bare `except:` (which also swallowed KeyboardInterrupt /
                # SystemExit); narrowed to what a partially-streamed payload can raise
                except (json.JSONDecodeError, KeyError, TypeError):
                    # args are still mid-stream; best effort: strip the known JSON prefix
                    prefix = '{\n "message": "'
                    if len(function_args) < len(prefix):
                        message = "..."
                    elif function_args.startswith(prefix):
                        message = function_args[len(prefix) :]
                    else:
                        message = function_args
                message_string += f"🤖 [bold yellow]{message}[/bold yellow]"
            else:
                message_string += f"{function_name}({function_args})"

        self.update_output(message_string)

    def stream_start(self):
        """Start a Live render session (no-op when streaming is toggled off)."""
        if self.streaming:
            print()
            self.live.start()  # Start the Live display context and keep it running
            self.update_output("💭 [italic]...[/italic]")

    def stream_end(self):
        """Stop the Live render session and prepare a fresh one for the next stream."""
        if self.streaming:
            if self.live.is_started:
                self.live.stop()
                print()
                # A stopped Live is replaced rather than restarted
                self.live = Live("", console=self.console, refresh_per_second=10)

    @staticmethod
    def important_message(msg: str):
        StreamingCLIInterface.nonstreaming_interface.important_message(msg)

    @staticmethod
    def warning_message(msg: str):
        StreamingCLIInterface.nonstreaming_interface.warning_message(msg)

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        # Suppressed while streaming: process_refresh already renders inner thoughts
        if self.disable_inner_mono_call:
            return
        StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        # Suppressed while streaming: process_refresh already renders send_message specially
        if self.separate_send_message:
            return
        StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)

    @staticmethod
    def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
        # raw/dump/debug kept for interface compatibility; not forwarded (as before)
        StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
        StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)

    @staticmethod
    def step_yield():
        pass
|
||||
Reference in New Issue
Block a user