import asyncio
import queue
from typing import Optional

from memgpt.data_types import Message
from memgpt.interface import AgentInterface
from memgpt.utils import is_utc_datetime


class QueuingInterface(AgentInterface):
    """Messages are queued inside an internal buffer and manually flushed.

    Each handler turns an agent event into a dict and puts it on ``self.buffer``
    (a ``queue.Queue``). The string ``"STOP"`` is enqueued as a sentinel that
    marks the end of a step; consumers drain the buffer either all at once
    (``to_list``) or incrementally (``message_generator``).
    """

    def __init__(self, debug=True):
        """Create an empty buffer.

        Args:
            debug: When true, every handler also prints the message (and the
                attached ``msg_obj``, if any) to stdout.
        """
        self.buffer = queue.Queue()
        self.debug = debug

    def _debug_print(self, msg: str, msg_obj: Optional[Message]) -> None:
        """Print ``msg`` and, when available, ``msg_obj`` details if debugging is on."""
        if not self.debug:
            return
        print(msg)
        if msg_obj is not None:
            print(vars(msg_obj))
            print(msg_obj.created_at.isoformat())

    @staticmethod
    def _add_metadata(new_message: dict, msg_obj: Optional[Message]) -> None:
        """Copy ``id`` and ``date`` from ``msg_obj`` into ``new_message`` (no-op if None)."""
        if msg_obj is None:
            return
        new_message["id"] = str(msg_obj.id)
        assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
        new_message["date"] = msg_obj.created_at.isoformat()

    def to_list(self):
        """Convert queue to a list (empties it out at the same time).

        Returns:
            The buffered items in FIFO order. A trailing ``"STOP"`` sentinel is
            dropped, including when it is the only item in the buffer.
        """
        items = []
        while True:
            try:
                items.append(self.buffer.get_nowait())
            except queue.Empty:
                break
        # Strip the end-of-step sentinel so callers only see real messages.
        if items and items[-1] == "STOP":
            items.pop()
        return items

    def clear(self):
        """Clear all messages from the queue."""
        # Hold the queue's own lock while touching its underlying deque.
        with self.buffer.mutex:
            self.buffer.queue.clear()

    async def message_generator(self):
        """Asynchronously yield buffered messages until ``"STOP"`` is dequeued.

        Polls the buffer without blocking; when it is empty, sleeps briefly so
        other coroutines can run.
        """
        while True:
            try:
                # get_nowait() never blocks the event loop, even if another
                # consumer empties the queue between iterations.
                message = self.buffer.get_nowait()
            except queue.Empty:
                await asyncio.sleep(0.1)  # Small sleep to prevent a busy loop
                continue
            if message == "STOP":
                break
            yield message

    def step_yield(self):
        """Enqueue a special stop message"""
        self.buffer.put("STOP")

    def error(self, error: str):
        """Enqueue an ``internal_error`` message followed by the stop message"""
        self.buffer.put({"internal_error": error})
        self.buffer.put("STOP")

    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Handle reception of a user message.

        Nothing is enqueued; the message is only printed when ``debug`` is set.
        """
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
        self._debug_print(msg, msg_obj)

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """Handle the agent's internal monologue.

        Enqueues ``{"internal_monologue": msg}`` plus ``id``/``date`` metadata
        taken from ``msg_obj``.
        """
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
        self._debug_print(msg, msg_obj)

        new_message = {"internal_monologue": msg}
        self._add_metadata(new_message, msg_obj)
        self.buffer.put(new_message)

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """Handle the agent sending a message.

        Enqueues ``{"assistant_message": msg}``. Metadata comes from ``msg_obj``
        when given; otherwise ``id`` and ``date`` are copied from the most
        recently buffered message.
        """
        self._debug_print(msg, msg_obj)

        new_message = {"assistant_message": msg}
        if msg_obj is not None:
            self._add_metadata(new_message, msg_obj)
        else:
            # FIXME this is a total hack
            # Only the last buffered message is read, so one message is enough.
            assert self.buffer.qsize() > 0, "Tried to reach back to grab function call data, but couldn't find a buffer message."
            # TODO also should not be accessing protected member here
            previous = self.buffer.queue[-1]
            new_message["id"] = previous["id"]
            new_message["date"] = previous["date"]

        self.buffer.put(new_message)

    def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
        """Handle the agent calling a function.

        The leading prefix of ``msg`` selects the shape of the queued dict:

        * ``"Running "``  -> ``{"function_call": <rest>}``
        * ``"Ran "``      -> ``{"function_call": "Function call returned: <rest>"}``,
          dropped entirely unless ``include_ran_messages`` is true
        * ``"Success: "`` -> ``{"function_return": <rest>, "status": "success"}``
        * ``"Error: "``   -> ``{"function_return": <rest>, "status": "error"}``
        * anything else   -> ``{"function_message": msg}``

        Only the leading prefix is removed; later occurrences of the same text
        inside the payload are left untouched.
        """
        # TODO handle 'function' messages that indicate the start of a function call
        assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
        self._debug_print(msg, msg_obj)

        if msg.startswith("Running "):
            new_message = {"function_call": msg[len("Running "):]}
        elif msg.startswith("Ran "):
            if not include_ran_messages:
                return
            new_message = {"function_call": "Function call returned: " + msg[len("Ran "):]}
        elif msg.startswith("Success: "):
            new_message = {"function_return": msg[len("Success: "):], "status": "success"}
        elif msg.startswith("Error: "):
            new_message = {"function_return": msg[len("Error: "):], "status": "error"}
        else:
            # NOTE: generic, should not happen
            new_message = {"function_message": msg}

        self._add_metadata(new_message, msg_obj)
        self.buffer.put(new_message)