Files
letta-server/memgpt/server/rest_api/server.py
Charles Packer b7427e2de7 API server refactor + REST API (#593)
* init server refactor

* refactored websockets server/client code to use internal server API

* added intentional fail on test

* update workflow to try and get test to pass remotely

* refactor to put websocket code in a separate subdirectory

* added fastapi rest server

* add error handling

* modified interface return style

* disabled certain tests on remote

* added SSE response option for user_message

* fix ws interface test

* fallback for oai key

* add soft fail for test when localhost is borked

* add step_yield for all server related interfaces

* extra catch

* update toml + lock with server add-ons (add uvicorn+fastapi, move websockets to server extra)

* regen lock file

* added pytest-asyncio as an extra in dev

* add pydantic to deps

* renamed CreateConfig to CreateAgentConfig

* fixed POST request for creating agent + tested it
2023-12-11 15:08:42 -08:00

113 lines
3.4 KiB
Python

import asyncio
import json
from typing import Union
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.interface import QueuingInterface
import memgpt.utils as utils
"""
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
Start the server with:
cd memgpt/server/rest_api
poetry run uvicorn server:app --reload
"""
class CreateAgentConfig(BaseModel):
    """Request body for POST /agents: create a new agent for a user."""

    # ID of the user the new agent will belong to.
    user_id: str
    # Agent configuration passed through to SyncServer.create_agent.
    # NOTE(review): schema of this dict is defined by the server layer — not visible here.
    config: dict
class UserMessage(BaseModel):
    """Request body for POST /agents/message: send a user message to an agent."""

    # ID of the user sending the message.
    user_id: str
    # ID of the agent that should process the message.
    agent_id: str
    # The message text itself.
    message: str
    # When True, the endpoint returns a Server-Sent Events stream instead of a JSON list.
    stream: bool = False
class Command(BaseModel):
    """Request body for POST /agents/command: run a CLI-style command on an agent."""

    # ID of the user issuing the command.
    user_id: str
    # ID of the agent the command targets.
    agent_id: str
    # Command string passed through to SyncServer.run_command.
    command: str
# Module-level singletons: one FastAPI app, one queuing interface shared by all
# endpoints (messages produced by the server are buffered here), and one
# synchronous MemGPT server that writes into that interface.
app = FastAPI()
interface = QueuingInterface()
server = SyncServer(default_interface=interface)
# server.list_agents
@app.get("/agents")
def list_agents(user_id: str):
    """Return the names of all agents with saved config files, plus a count.

    NOTE(review): user_id is currently unused — the listing is global, reading
    config files directly via utils rather than going through the server.
    """
    interface.clear()
    available = utils.list_agent_config_files()
    return {"num_agents": len(available), "agent_names": available}
# server.create_agent
@app.post("/agents")
def create_agents(body: CreateAgentConfig):
    """Create a new agent from the supplied config and return its id.

    Any server-side failure is surfaced as an HTTP 500 whose detail is the
    stringified exception.
    """
    interface.clear()
    try:
        return {"agent_id": server.create_agent(user_id=body.user_id, agent_config=body.config)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
# server.user_message
@app.post("/agents/message")
async def user_message(body: UserMessage):
    """Send a user message to an agent.

    Two modes, selected by body.stream:
      - stream=True: kick off message processing in the background and return a
        Server-Sent Events stream fed from the shared QueuingInterface.
      - stream=False: process the message synchronously and return the buffered
        interface messages as a JSON list.

    Raises HTTP 500 (with the stringified exception as detail) on failure.
    """
    if body.stream:
        # For streaming response
        try:
            # Start the generation process (similar to the non-streaming case)
            # This should be a non-blocking call or run in a background task
            # Check if server.user_message is an async function
            if asyncio.iscoroutinefunction(server.user_message):
                # Start the async task
                # NOTE(review): fire-and-forget — the task handle is dropped, so
                # exceptions raised inside it are lost rather than becoming a 500.
                asyncio.create_task(server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message))
            else:
                # Run the synchronous function in a thread pool
                # NOTE(review): the returned future is not awaited, so errors from
                # the executor are also silently dropped.
                loop = asyncio.get_event_loop()
                loop.run_in_executor(None, server.user_message, body.user_id, body.agent_id, body.message)

            async def formatted_message_generator():
                # Drain messages from the shared interface and emit them in SSE
                # ("data: ...\n\n") framing; sleeps 1s between messages to pace output.
                async for message in interface.message_generator():
                    formatted_message = f"data: {json.dumps(message)}\n\n"
                    yield formatted_message
                    await asyncio.sleep(1)

            # Return the streaming response using the generator
            return StreamingResponse(formatted_message_generator(), media_type="text/event-stream")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
    else:
        interface.clear()
        try:
            server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return {"messages": interface.to_list()}
# server.run_command
@app.post("/agents/command")
def run_command(body: Command):
    """Run a CLI-style command against an agent and return its response.

    Raises HTTP 500 (with the stringified exception as detail) if the command fails.

    Fix: the original called server.run_command twice — once inside the
    try/except and again after it — so every command's side effects were
    executed twice, and an error in the second (unguarded) call bypassed the
    HTTPException handler. The duplicate call is removed; only the guarded
    call remains.
    """
    interface.clear()
    try:
        response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
    return {"response": response}