import asyncio
from datetime import datetime
from functools import partial
from typing import List, Optional, Union

from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse

from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.message import Message, UpdateMessage
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate

router = APIRouter()


# TODO: cpacker should check this file
# TODO: move this into server.py?
async def send_message_to_agent(
    server: SyncServer,
    agent_id: str,
    user_id: str,
    role: MessageRole,
    message: str,
    stream_steps: bool,
    stream_tokens: bool,
    return_message_object: bool,  # Should be True for Python Client, False for REST API
    chat_completion_mode: Optional[bool] = False,
    timestamp: Optional[datetime] = None,
) -> Union[StreamingResponse, MemGPTResponse]:
    """Send one message to an agent and return its output, streamed or buffered.

    Split off into a separate function so that it can be imported in the
    /chat/completion proxy.

    Args:
        server: Server whose `user_message` / `system_message` handler is invoked.
        agent_id: ID of the agent that receives the message.
        user_id: ID of the user on whose behalf the message is sent.
        role: `MessageRole.user` or `MessageRole.system`; selects the handler.
        message: Text of the message.
        stream_steps: If true, return a server-sent-events `StreamingResponse`;
            otherwise buffer the whole stream and return a `MemGPTResponse`.
        stream_tokens: Assigned to the interface's `streaming_mode`; only valid
            together with `stream_steps`.
        return_message_object: Whether the response carries `Message` objects
            (looked up by ID) instead of the `MemGPTMessage`s taken from the
            stream. Not supported together with `stream_steps`.
        chat_completion_mode: Assigned to the interface's
            `streaming_chat_completion_mode`.
        timestamp: Forwarded unchanged to the message handler.

    Returns:
        A `StreamingResponse` (media type `text/event-stream`) when
        `stream_steps` is true, else a `MemGPTResponse` with the collected
        messages and the usage value returned by the message handler.

    Raises:
        HTTPException: 500 for an unsupported role, 400 when `stream_tokens` is
            set without `stream_steps`, and 500 wrapping any other error
            (including the unsupported streaming + `Message` combination).
    """
    # TODO: @charles is this the correct way to handle?
    include_final_message = True

    # determine role
    if role == MessageRole.user:
        message_func = server.user_message
    elif role == MessageRole.system:
        message_func = server.system_message
    else:
        raise HTTPException(status_code=500, detail=f"Bad role {role}")

    if not stream_steps and stream_tokens:
        raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")

    # For streaming response
    try:
        # TODO: move this logic into server.py

        # Get the generator object off of the agent's streaming interface
        # This will be attached to the POST SSE request used under-the-hood
        memgpt_agent = server._get_or_load_agent(agent_id=agent_id)
        streaming_interface = memgpt_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")

        # Reject the unsupported combination *before* touching the interface or
        # starting the worker, so a request that is going to fail does not
        # still get processed by the agent in the background.
        if stream_steps and return_message_object:
            # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
            raise NotImplementedError("Returning `Message` objects is not supported when stream_steps is true")

        # Enable token-streaming within the request if desired
        streaming_interface.streaming_mode = stream_tokens
        # "chatcompletion mode" does some remapping and ignores inner thoughts
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode

        # streaming_interface.allow_assistant_message = stream
        # streaming_interface.function_call_legacy_mode = stream

        # Offload the synchronous message_func to a separate thread
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
        )

        if stream_steps:
            # return a stream
            return StreamingResponse(
                sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
                media_type="text/event-stream",
            )

        # buffer the stream, then return the list
        generated_stream = []
        async for chunk in streaming_interface.get_generator():
            # Internal invariant on what the interface may yield (not input validation).
            assert isinstance(chunk, (MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus)), type(chunk)
            generated_stream.append(chunk)
            if chunk == MessageStreamStatus.done:
                break

        # Get rid of the stream status messages
        filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
        usage = await task

        # By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
        # If we want to convert these to Message, we can use the attached IDs
        # NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
        # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
        if return_message_object:
            message_ids = deduplicate([m.id for m in filtered_stream])
            message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
            return MemGPTResponse(messages=message_objs, usage=usage)

        return MemGPTResponse(messages=filtered_stream, usage=usage)

    except HTTPException:
        raise
    except Exception as e:
        print(e)
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e


def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the agent message endpoints on the module-level router.

    Args:
        server: Server that the endpoints delegate to.
        interface: Interface that is cleared before message history is read.
        password: Bound, together with `server`, into the `get_current_user`
            dependency used by every endpoint.

    Returns:
        The module-level `router` with the three message routes attached.
    """
    get_current_user_with_server = partial(get_current_user, server, password)

    @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message])
    def get_agent_messages(
        agent_id: str,
        before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
        limit: int = Query(10, description="Maximum number of messages to retrieve."),
        msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve message history for an agent.
        """
        interface.clear()
        return server.get_agent_recall_cursor(
            user_id=user_id,
            agent_id=agent_id,
            before=before,
            limit=limit,
            reverse=True,
            return_message_object=msg_object,
        )

    @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse)
    async def send_message(
        # background_tasks: BackgroundTasks,
        agent_id: str,
        request: MemGPTRequest = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Process a user message and return the agent's response.

        This endpoint accepts a message from a user and processes it through the agent.
        It can optionally stream the response if 'stream_steps' is set to True
        (and additionally stream individual tokens if 'stream_tokens' is set to True).
        """
        # TODO: should this receive multiple messages? @cpacker
        # TODO: revise to `MemGPTRequest`
        # TODO: support sending multiple messages
        # Explicit check instead of `assert`: asserts are stripped under `-O`
        # (which would silently drop every message after the first) and a
        # malformed request is a client error, not a server error.
        if len(request.messages) != 1:
            raise HTTPException(
                status_code=400,
                detail=f"Exactly one message is required, got {len(request.messages)}: {request.messages}",
            )
        message = request.messages[0]

        # TODO: what to do with message.name?
        return await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=user_id,
            role=message.role,
            message=message.text,
            stream_steps=request.stream_steps,
            stream_tokens=request.stream_tokens,
            return_message_object=request.return_message_object,
        )

    @router.patch("/agents/{agent_id}/messages/{message_id}", tags=["agents"], response_model=Message)
    async def update_message(
        agent_id: str,
        message_id: str,
        request: UpdateMessage = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Update the details of a message associated with an agent.
        """
        # Explicit check instead of `assert` so the guard survives `-O` and the
        # client gets a 400 rather than an unhandled AssertionError.
        if request.id != message_id:
            raise HTTPException(status_code=400, detail=f"Message ID mismatch: {request.id} != {message_id}")
        return server.update_agent_message(agent_id=agent_id, request=request)

    return router