Files
letta-server/memgpt/server/rest_api/interface.py

151 lines
5.4 KiB
Python

import asyncio
import queue
from datetime import datetime
from typing import Optional
import pytz
from memgpt.interface import AgentInterface
from memgpt.data_types import Message
from memgpt.utils import is_utc_datetime
class QueuingInterface(AgentInterface):
    """Messages are queued inside an internal buffer and manually flushed.

    Each handler converts an agent event into a dict and puts it on
    ``self.buffer``. The plain string ``"STOP"`` is enqueued as an
    end-of-step sentinel. Consumers drain the buffer with ``to_list()`` or
    ``message_generator()``.
    """

    # Sentinel enqueued by step_yield() / error() to mark the end of a step.
    _STOP = "STOP"

    def __init__(self, debug=True):
        """Create an empty buffer.

        Args:
            debug: When true, handlers print each message (and its
                ``msg_obj`` details, when given) to stdout.
        """
        self.buffer = queue.Queue()
        self.debug = debug

    def to_list(self):
        """Convert queue to a list (empties it out at the same time).

        Returns:
            The buffered items in FIFO order. A trailing "STOP" sentinel is
            dropped, including when it is the only item in the buffer.
        """
        items = []
        while True:
            try:
                items.append(self.buffer.get_nowait())
            except queue.Empty:
                break
        # The sentinel is internal bookkeeping; never hand it to the caller.
        if items and items[-1] == self._STOP:
            items.pop()
        return items

    def clear(self):
        """Clear all messages from the queue."""
        with self.buffer.mutex:
            # Empty the underlying deque while holding the queue's own lock.
            self.buffer.queue.clear()

    async def message_generator(self):
        """Asynchronously yield buffered messages until "STOP" is dequeued.

        Polls the buffer, sleeping briefly whenever it is empty. The "STOP"
        sentinel ends the generator and is not yielded.
        """
        while True:
            try:
                # Non-blocking get: a blocking get() here could stall the
                # event loop if the buffer is drained between check and read.
                message = self.buffer.get_nowait()
            except queue.Empty:
                await asyncio.sleep(0.1)  # Small sleep to prevent a busy loop
                continue
            if message == self._STOP:
                break
            yield message

    def step_yield(self):
        """Enqueue a special stop message"""
        self.buffer.put(self._STOP)

    def error(self, error: str):
        """Enqueue an ``internal_error`` message followed by the stop message"""
        self.buffer.put({"internal_error": error})
        self.buffer.put(self._STOP)

    def _debug_print(self, msg, msg_obj):
        """Print the message and, if present, its msg_obj details (debug only)."""
        if not self.debug:
            return
        print(msg)
        if msg_obj is not None:
            print(vars(msg_obj))
            print(msg_obj.created_at.isoformat())

    @staticmethod
    def _add_metadata(new_message, msg_obj):
        """Copy ``id`` and ISO-formatted ``date`` from msg_obj onto new_message."""
        if msg_obj is None:
            return
        new_message["id"] = str(msg_obj.id)
        assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
        new_message["date"] = msg_obj.created_at.isoformat()

    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Handle reception of a user message.

        Nothing is enqueued; the message is only printed when debugging.
        """
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
        self._debug_print(msg, msg_obj)

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """Handle the agent's internal monologue.

        Enqueues ``{"internal_monologue": msg}`` plus ``id``/``date`` metadata
        taken from ``msg_obj``.
        """
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
        self._debug_print(msg, msg_obj)

        new_message = {"internal_monologue": msg}
        self._add_metadata(new_message, msg_obj)
        self.buffer.put(new_message)

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """Handle the agent sending a message.

        Enqueues ``{"assistant_message": msg}``. Metadata comes from
        ``msg_obj`` when given; otherwise ``id`` and ``date`` are copied from
        the most recently buffered message.
        """
        self._debug_print(msg, msg_obj)

        new_message = {"assistant_message": msg}
        if msg_obj is not None:
            self._add_metadata(new_message, msg_obj)
        else:
            # FIXME this is a total hack: reuse the metadata of the last
            # buffered message because no msg_obj was supplied.
            # TODO also should not be accessing protected member here
            with self.buffer.mutex:
                # Peek under the lock so the entry cannot vanish mid-read.
                previous = self.buffer.queue[-1] if self.buffer.queue else None
            # Only the single most recent entry is needed, so one buffered
            # message is sufficient.
            assert previous is not None, "Tried to reach back to grab function call data, but couldn't find a buffer message."
            new_message["id"] = previous["id"]
            new_message["date"] = previous["date"]

        self.buffer.put(new_message)

    def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
        """Handle the agent calling a function.

        The leading marker of ``msg`` selects the emitted dict:

        * ``"Running "``  -> ``{"function_call": <rest>}``
        * ``"Ran "``      -> ``{"function_call": "Function call returned: <rest>"}``
          (skipped entirely unless ``include_ran_messages`` is true)
        * ``"Success: "`` -> ``{"function_return": <rest>, "status": "success"}``
        * ``"Error: "``   -> ``{"function_return": <rest>, "status": "error"}``
        * anything else   -> ``{"function_message": msg}``

        Only the leading marker is removed; identical text appearing later in
        the message is left untouched.
        """
        # TODO handle 'function' messages that indicate the start of a function call
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
        self._debug_print(msg, msg_obj)

        if msg.startswith("Running "):
            new_message = {"function_call": msg[len("Running "):]}
        elif msg.startswith("Ran "):
            if not include_ran_messages:
                return
            new_message = {"function_call": "Function call returned: " + msg[len("Ran "):]}
        elif msg.startswith("Success: "):
            new_message = {"function_return": msg[len("Success: "):], "status": "success"}
        elif msg.startswith("Error: "):
            new_message = {"function_return": msg[len("Error: "):], "status": "error"}
        else:
            # NOTE: generic, should not happen
            new_message = {"function_message": msg}

        self._add_metadata(new_message, msg_obj)
        self.buffer.put(new_message)