127 lines
4.2 KiB
Python
127 lines
4.2 KiB
Python
import asyncio
|
|
import queue
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
import pytz
|
|
|
|
from memgpt.interface import AgentInterface
|
|
from memgpt.data_types import Message
|
|
|
|
|
|
class QueuingInterface(AgentInterface):
    """Agent interface that queues messages in an internal buffer.

    Each agent event is converted to a dict and pushed onto a ``queue.Queue``.
    Consumers drain the buffer explicitly, either all at once with
    :meth:`to_list` or incrementally with :meth:`message_generator`.
    The literal string ``"STOP"`` is used as an end-of-step marker.
    """

    def __init__(self, debug=True):
        """Create an empty buffer.

        Args:
            debug: When true, each agent message is also printed to stdout
                before it is enqueued.
        """
        self.buffer = queue.Queue()
        self.debug = debug

    @staticmethod
    def _add_metadata(new_message: dict, msg_obj: Optional[Message]) -> None:
        """Copy the id and creation date of ``msg_obj`` into ``new_message``."""
        # The None check is kept even though callers assert beforehand:
        # asserts are stripped when Python runs with -O.
        if msg_obj is not None:
            new_message["id"] = str(msg_obj.id)
            new_message["date"] = msg_obj.created_at.isoformat()

    def to_list(self):
        """Drain the queue into a list and return it.

        The queue is emptied as a side effect. A trailing ``"STOP"`` marker
        is dropped when the drained list holds more than one item.
        """
        items = []
        while True:
            try:
                items.append(self.buffer.get_nowait())
            except queue.Empty:
                break
        # NOTE(review): a list consisting solely of "STOP" is returned
        # unchanged because of the `> 1` check -- confirm this is intended.
        if len(items) > 1 and items[-1] == "STOP":
            items.pop()
        return items

    def clear(self):
        """Clear all messages from the queue."""
        # Hold the queue's own lock while emptying the underlying container.
        with self.buffer.mutex:
            self.buffer.queue.clear()

    async def message_generator(self):
        """Asynchronously yield buffered messages until ``"STOP"`` is seen.

        While the buffer is empty the generator sleeps briefly and polls
        again. The ``"STOP"`` marker itself is consumed but not yielded.
        """
        while True:
            try:
                # Non-blocking get: a blocking get() here could stall the
                # event loop if another consumer drained the queue first.
                message = self.buffer.get_nowait()
            except queue.Empty:
                await asyncio.sleep(0.1)  # Small sleep to prevent a busy loop
                continue
            if message == "STOP":
                break
            # yield message | {"date": datetime.now(tz=pytz.utc).isoformat()}
            yield message

    def step_yield(self):
        """Enqueue a special stop message"""
        self.buffer.put("STOP")

    def error(self, error: str):
        """Enqueue an ``internal_error`` message followed by a stop message.

        Args:
            error: Error text stored under the ``"internal_error"`` key.
        """
        self.buffer.put({"internal_error": error})
        self.buffer.put("STOP")

    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Handle reception of a user message.

        Only validates that ``msg_obj`` was supplied; nothing is enqueued.

        Raises:
            AssertionError: If ``msg_obj`` is ``None``.
        """
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """Handle the agent's internal monologue.

        Enqueues ``{"internal_monologue": msg}`` plus ``id`` and ``date``
        taken from ``msg_obj``.

        Raises:
            AssertionError: If ``msg_obj`` is ``None``.
        """
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
        if self.debug:
            print(msg)

        new_message = {"internal_monologue": msg}
        self._add_metadata(new_message, msg_obj)
        self.buffer.put(new_message)

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """Handle the agent sending a message.

        Enqueues ``{"assistant_message": msg}`` plus ``id`` and ``date``
        taken from ``msg_obj``.

        Raises:
            AssertionError: If ``msg_obj`` is ``None``.
        """
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
        if self.debug:
            print(msg)

        new_message = {"assistant_message": msg}
        self._add_metadata(new_message, msg_obj)
        self.buffer.put(new_message)

    def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
        """Handle the agent calling a function.

        The leading prefix of ``msg`` selects the shape of the queued dict:

        * ``"Running "``  -> ``{"function_call": <rest>}``
        * ``"Ran "``      -> ``{"function_call": "Function call returned: <rest>"}``,
          skipped entirely unless ``include_ran_messages`` is true
        * ``"Success: "`` -> ``{"function_return": <rest>, "status": "success"}``
        * ``"Error: "``   -> ``{"function_return": <rest>, "status": "error"}``
        * anything else   -> ``{"function_message": msg}``

        Only the leading prefix is stripped or rewritten; later occurrences
        of the same text inside the payload are preserved.

        Raises:
            AssertionError: If ``msg_obj`` is ``None``.
        """
        # TODO handle 'function' messages that indicate the start of a function call
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"

        if self.debug:
            print(msg)

        # count=1 limits each replacement to the prefix matched by startswith.
        if msg.startswith("Running "):
            new_message = {"function_call": msg.replace("Running ", "", 1)}

        elif msg.startswith("Ran "):
            if not include_ran_messages:
                return
            new_message = {"function_call": msg.replace("Ran ", "Function call returned: ", 1)}

        elif msg.startswith("Success: "):
            new_message = {"function_return": msg.replace("Success: ", "", 1), "status": "success"}

        elif msg.startswith("Error: "):
            new_message = {"function_return": msg.replace("Error: ", "", 1), "status": "error"}

        else:
            # NOTE: generic, should not happen
            new_message = {"function_message": msg}

        self._add_metadata(new_message, msg_obj)
        self.buffer.put(new_message)
|