diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py
index 5ca17caf..683fade5 100644
--- a/memgpt/server/rest_api/agents/config.py
+++ b/memgpt/server/rest_api/agents/config.py
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Query
 from pydantic import BaseModel, Field
 
 from memgpt.server.rest_api.interface import QueuingInterface
@@ -8,8 +8,8 @@ router = APIRouter()
 
 
 class AgentConfigRequest(BaseModel):
-    user_id: str = Field(..., description="Unique identifier of the user issuing the command.")
-    agent_id: str = Field(..., description="Identifier of the agent on which the command will be executed.")
+    user_id: str = Field(..., description="Unique identifier of the user requesting the config.")
+    agent_id: str = Field(..., description="Identifier of the agent whose config is requested.")
 
 
 class AgentConfigResponse(BaseModel):
@@ -18,12 +18,17 @@ class AgentConfigResponse(BaseModel):
 
 def setup_agents_config_router(server: SyncServer, interface: QueuingInterface):
     @router.get("/agents/config", tags=["agents"], response_model=AgentConfigResponse)
-    def get_agent_config(request: AgentConfigRequest = Depends()):
+    def get_agent_config(
+        user_id: str = Query(..., description="Unique identifier of the user requesting the config."),
+        agent_id: str = Query(..., description="Identifier of the agent whose config is requested."),
+    ):
         """
         Retrieve the configuration for a specific agent.
 
         This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
         """
+        request = AgentConfigRequest(user_id=user_id, agent_id=agent_id)
+
         interface.clear()
         config = server.get_agent_config(user_id=request.user_id, agent_id=request.agent_id)
         return AgentConfigResponse(config=config)
diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py
index 3da20106..c4d26849 100644
--- a/memgpt/server/rest_api/agents/index.py
+++ b/memgpt/server/rest_api/agents/index.py
@@ -1,6 +1,6 @@
 from typing import List
 
-from fastapi import APIRouter, Depends, Body, HTTPException
+from fastapi import APIRouter, Depends, Body, Query, HTTPException
 from pydantic import BaseModel, Field
 
 from memgpt.server.rest_api.interface import QueuingInterface
@@ -29,12 +29,14 @@ class CreateAgentResponse(BaseModel):
 
 def setup_agents_index_router(server: SyncServer, interface: QueuingInterface):
     @router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
-    def list_agents(request: ListAgentsRequest = Depends()):
+    def list_agents(user_id: str = Query(..., description="Unique identifier of the user.")):
         """
         List all agents associated with a given user.
 
         This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
         """
+        request = ListAgentsRequest(user_id=user_id)
+
         interface.clear()
         agents_data = server.list_agents(user_id=request.user_id)
         return ListAgentsResponse(**agents_data)
diff --git a/memgpt/server/rest_api/agents/memory.py b/memgpt/server/rest_api/agents/memory.py
index 416d109d..7ab690ac 100644
--- a/memgpt/server/rest_api/agents/memory.py
+++ b/memgpt/server/rest_api/agents/memory.py
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from fastapi import APIRouter, Depends, Body
+from fastapi import APIRouter, Depends, Body, Query
 from pydantic import BaseModel, Field
 
 from memgpt.server.rest_api.interface import QueuingInterface
@@ -29,8 +29,8 @@ class GetAgentMemoryResponse(BaseModel):
 class UpdateAgentMemoryRequest(BaseModel):
     user_id: str = Field(..., description="The unique identifier of the user.")
     agent_id: str = Field(..., description="The unique identifier of the agent.")
-    human: Optional[str] = Field(None, description="Human element of the core memory.")
-    persona: Optional[str] = Field(None, description="Persona element of the core memory.")
+    human: str = Field(None, description="Human element of the core memory.")
+    persona: str = Field(None, description="Persona element of the core memory.")
 
 
 class UpdateAgentMemoryResponse(BaseModel):
@@ -40,12 +40,18 @@ class UpdateAgentMemoryResponse(BaseModel):
 
 def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface):
     @router.get("/agents/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
-    def get_agent_memory(request: GetAgentMemoryRequest = Depends()):
+    def get_agent_memory(
+        user_id: str = Query(..., description="The unique identifier of the user."),
+        agent_id: str = Query(..., description="The unique identifier of the agent."),
+    ):
         """
         Retrieve the memory state of a specific agent.
 
         This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
         """
+        # Validate with the Pydantic model (optional)
+        request = GetAgentMemoryRequest(user_id=user_id, agent_id=agent_id)
+
         interface.clear()
         memory = server.get_agent_memory(user_id=request.user_id, agent_id=request.agent_id)
         return GetAgentMemoryResponse(**memory)
diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py
index f0c6a7ad..3ffac6ac 100644
--- a/memgpt/server/rest_api/agents/message.py
+++ b/memgpt/server/rest_api/agents/message.py
@@ -1,9 +1,10 @@
 import asyncio
+from enum import Enum
 import json
-from typing import List
+from typing import List, Optional
 
-from fastapi import APIRouter, Depends, Body, HTTPException
-from pydantic import BaseModel, Field
+from fastapi import APIRouter, Depends, Body, HTTPException, Query
+from pydantic import BaseModel, Field, constr, validator
 from starlette.responses import StreamingResponse
 
 from memgpt.server.rest_api.interface import QueuingInterface
@@ -12,26 +13,67 @@ from memgpt.server.server import SyncServer
 router = APIRouter()
 
 
+class MessageRoleType(str, Enum):
+    user = "user"
+    system = "system"
+
+
 class UserMessageRequest(BaseModel):
     user_id: str = Field(..., description="The unique identifier of the user.")
     agent_id: str = Field(..., description="The unique identifier of the agent.")
     message: str = Field(..., description="The message content to be processed by the agent.")
     stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
+    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
 
 
 class UserMessageResponse(BaseModel):
-    messages: List[str] = Field(..., description="List of messages generated by the agent.")
+    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
+
+
+class GetAgentMessagesRequest(BaseModel):
+    user_id: str = Field(..., description="The unique identifier of the user.")
+    agent_id: str = Field(..., description="The unique identifier of the agent.")
+    start: int = Field(..., description="Message index to start on (reverse chronological).")
+    count: int = Field(..., description="How many messages to retrieve.")
+
+
+class GetAgentMessagesResponse(BaseModel):
+    messages: list = Field(..., description="List of message objects.")
 
 
 def setup_agents_message_router(server: SyncServer, interface: QueuingInterface):
+    @router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse)
+    def get_agent_messages(
+        user_id: str = Query(..., description="The unique identifier of the user."),
+        agent_id: str = Query(..., description="The unique identifier of the agent."),
+        start: int = Query(..., description="Message index to start on (reverse chronological)."),
+        count: int = Query(..., description="How many messages to retrieve."),
+    ):
+        """
+        Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
+        """
+        # Validate with the Pydantic model (optional)
+        request = GetAgentMessagesRequest(user_id=user_id, agent_id=agent_id, start=start, count=count)
+
+        interface.clear()
+        messages = server.get_agent_messages(user_id=request.user_id, agent_id=request.agent_id, start=request.start, count=request.count)
+        return GetAgentMessagesResponse(messages=messages)
+
     @router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse)
-    async def user_message(request: UserMessageRequest = Body(...)):
+    async def send_message(request: UserMessageRequest = Body(...)):
         """
         Process a user message and return the agent's response.
 
         This endpoint accepts a message from a user and processes it through the agent.
         It can optionally stream the response if 'stream' is set to True.
         """
+        if request.role == "user" or request.role is None:
+            message_func = server.user_message
+        elif request.role == "system":
+            message_func = server.system_message
+        else:
+            raise HTTPException(status_code=500, detail=f"Bad role {request.role}")
+
         if request.stream:
             # For streaming response
             try:
@@ -39,15 +81,13 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface)
                 # This should be a non-blocking call or run in a background task
 
-                # Check if server.user_message is an async function
-                if asyncio.iscoroutinefunction(server.user_message):
+                # Check if the selected message handler is an async function
+                if asyncio.iscoroutinefunction(message_func):
                     # Start the async task
-                    await asyncio.create_task(
-                        server.user_message(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
-                    )
+                    await asyncio.create_task(message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message))
                 else:
                     # Run the synchronous function in a thread pool
                     loop = asyncio.get_event_loop()
-                    loop.run_in_executor(None, server.user_message, request.user_id, request.agent_id, request.message)
+                    loop.run_in_executor(None, message_func, request.user_id, request.agent_id, request.message)
 
                 async def formatted_message_generator():
                     async for message in interface.message_generator():
@@ -65,7 +105,7 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface)
         else:
             interface.clear()
             try:
-                server.user_message(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
+                message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
             except HTTPException:
                 raise
             except Exception as e:
diff --git a/memgpt/server/rest_api/config/index.py b/memgpt/server/rest_api/config/index.py
index 36771bd4..ffc09760 100644
--- a/memgpt/server/rest_api/config/index.py
+++ b/memgpt/server/rest_api/config/index.py
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, Query, HTTPException
 from pydantic import BaseModel, Field
 
 
@@ -9,7 +9,7 @@ router = APIRouter()
 
 
 class ConfigRequest(BaseModel):
-    user_id: str = Field(..., description="Unique identifier of the user issuing the command.")
+    user_id: str = Field(..., description="Unique identifier of the user requesting the config.")
 
 
 class ConfigResponse(BaseModel):
@@ -18,12 +18,14 @@ class ConfigResponse(BaseModel):
 
 def setup_config_index_router(server: SyncServer, interface: QueuingInterface):
     @router.get("/config", tags=["config"], response_model=ConfigResponse)
-    def get_server_config(user_id: ConfigRequest = Depends()):
+    def get_server_config(user_id: str = Query(..., description="Unique identifier of the user requesting the config.")):
         """
         Retrieve the base configuration for the server.
         """
+        request = ConfigRequest(user_id=user_id)
+
         interface.clear()
-        response = server.get_server_config(user_id=user_id)
+        response = server.get_server_config(user_id=request.user_id)
         return ConfigResponse(config=response)
 
     return router
diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py
index aff8e3ab..477c456e 100644
--- a/memgpt/server/rest_api/server.py
+++ b/memgpt/server/rest_api/server.py
@@ -1,3 +1,4 @@
+import json
 from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
@@ -58,6 +59,22 @@ app.include_router(setup_config_index_router(server, interface), prefix=API_PREF
 mount_static_files(app)
 
 
+@app.on_event("startup")
+def on_startup():
+    # Update the OpenAPI schema
+    if not app.openapi_schema:
+        app.openapi_schema = app.openapi()
+
+    if app.openapi_schema:
+        app.openapi_schema["servers"] = [{"url": "http://localhost:8283"}]
+        app.openapi_schema["info"]["title"] = "MemGPT API"
+
+    # Write out the OpenAPI schema to a file
+    with open("openapi.json", "w") as file:
+        print("Writing out openapi.json file")
+        json.dump(app.openapi_schema, file, indent=2)
+
+
 @app.on_event("shutdown")
 def on_shutdown():
     global server
diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index 7ba1cea4..733e6725 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -36,6 +36,11 @@ class Server(object):
         """List all available agents to a user"""
         raise NotImplementedError
 
+    @abstractmethod
+    def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
+        """Paginated query of in-context messages in agent message queue"""
+        raise NotImplementedError
+
     @abstractmethod
     def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
         """Return the memory of an agent (core memory + non-core statistics)"""
@@ -72,6 +77,11 @@ class Server(object):
         """Process a message from the user, internally calls step"""
         raise NotImplementedError
 
+    @abstractmethod
+    def system_message(self, user_id: str, agent_id: str, message: str) -> None:
+        """Process a message from the system, internally calls step"""
+        raise NotImplementedError
+
     @abstractmethod
     def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
         """Run a command on the agent, e.g. /memory
@@ -406,6 +416,26 @@ class SyncServer(LockingServer):
             # Run the agent state forward
             self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
 
+    @LockingServer.agent_lock_decorator
+    def system_message(self, user_id: str, agent_id: str, message: str) -> None:
+        """Process an incoming system message and feed it through the MemGPT agent"""
+        from memgpt.utils import printd
+
+        # Basic input sanitization
+        if not isinstance(message, str) or len(message) == 0:
+            raise ValueError(f"Invalid input: '{message}'")
+
+        # If the input begins with a command prefix, reject
+        elif message.startswith("/"):
+            raise ValueError(f"Invalid input: '{message}'")
+
+        # Else, process it as a system message to be fed to the agent
+        else:
+            # Package the system message first
+            packaged_system_message = package_system_message(system_message=message)
+            # Run the agent state forward
+            self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message)
+
     @LockingServer.agent_lock_decorator
     def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
         """Run a command on the agent"""
@@ -505,6 +535,29 @@ class SyncServer(LockingServer):
 
         return memory_obj
 
+    def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
+        """Paginated query of in-context messages in agent message queue"""
+        # Get the agent object (loaded in memory)
+        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
+
+        if start < 0 or count < 0:
+            raise ValueError("Start and count values should be non-negative")
+
+        # Reverse the list to make it in reverse chronological order
+        reversed_messages = memgpt_agent.messages[::-1]
+
+        # Check if start is within the range of the list
+        if start >= len(reversed_messages):
+            raise IndexError("Start index is out of range")
+
+        # Calculate the end index, ensuring it does not exceed the list length
+        end_index = min(start + count, len(reversed_messages))
+
+        # Slice the list for pagination
+        paginated_messages = reversed_messages[start:end_index]
+
+        return paginated_messages
+
     def get_agent_config(self, user_id: str, agent_id: str) -> dict:
         """Return the config of an agent"""
         # Get the agent object (loaded in memory)
diff --git a/memgpt/system.py b/memgpt/system.py
index 22e20598..f0813488 100644
--- a/memgpt/system.py
+++ b/memgpt/system.py
@@ -108,6 +108,17 @@ def package_function_response(was_success, response_string, timestamp=None):
     return json.dumps(packaged_message)
 
 
+def package_system_message(system_message, message_type="system_alert", time=None):
+    formatted_time = time if time else get_local_time()
+    packaged_message = {
+        "type": message_type,
+        "message": system_message,
+        "time": formatted_time,
+    }
+
+    return json.dumps(packaged_message)
+
+
 def package_summarize_message(summary, summary_length, hidden_message_count, total_message_count, timestamp=None):
     context_message = (
         f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n"