* First commit of memgpt client and some messy test code * rolled back unnecessary changes to abstract interface; switched client to always use Queueing Interface * Added missing interface clear() in run_command; added convenience method for checking if an agent exists, used that in create_agent * Formatting fixes * Fixed incorrect naming of get_agent_memory in rest server * Removed erroneous clear from client save method; Replaced print statements with appropriate logger calls in server * Updated readme with client usage instructions * added tests for Client * make printing to terminal togglable on queininginterface (should probably refactor this to a logger) * turn off printing to stdout via interface by default * allow importing the python client in a similar fashion to openai-python (see https://github.com/openai/openai-python) * Allowed quickstart on init of client; updated readme and test_client accordingly * oops, fixed name of openai_api_key config key * Fixed small typo * Fixed broken test by adding memgpt hosted model details to agent config * silence llamaindex 'LLM is explicitly disabled. Using MockLLM.' on server * default to openai if user's memgpt directory is empty (first time) * correct type hint * updated section on client in readme * added comment about how MemGPT config != Agent config * patch unrelated test * update wording on readme * patch another unrelated test * added python client to readme docs * Changed 'user' to 'human' in example; Defaulted AgentConfig.model to 'None'; Fixed issue in create_agent (accounting for dict config); matched test code to example * Fixed advanced example * patch test * patch --------- Co-authored-by: cpacker <packercharles@gmail.com>
167 lines
4.7 KiB
Python
167 lines
4.7 KiB
Python
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
import json
|
|
from typing import Union
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from memgpt.server.server import SyncServer
|
|
from memgpt.server.rest_api.interface import QueuingInterface
|
|
import memgpt.utils as utils
|
|
|
|
|
|
"""
|
|
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
|
|
|
|
Start the server with:
|
|
cd memgpt/server/rest_api
|
|
poetry run uvicorn server:app --reload
|
|
"""
|
|
|
|
|
|
class CreateAgentConfig(BaseModel):
    """Request body for POST /agents (create a new agent)."""

    # Id of the user who will own the new agent.
    user_id: str
    # Agent configuration dict; forwarded to SyncServer.create_agent as
    # `agent_config` (NOTE: this is an *agent* config, not the MemGPT config).
    config: dict
|
class UserMessage(BaseModel):
    """Request body for POST /agents/message."""

    user_id: str
    agent_id: str
    # The user's message text, passed to SyncServer.user_message.
    message: str
    # When True, the endpoint responds with a text/event-stream of agent
    # messages instead of a single JSON payload.
    stream: bool = False
class Command(BaseModel):
    """Request body for POST /agents/command (run a raw MemGPT command)."""

    user_id: str
    agent_id: str
    # Command string forwarded verbatim to SyncServer.run_command.
    command: str
|
class CoreMemory(BaseModel):
    """Request body for PUT /agents/memory.

    Both sections are optional; they are forwarded as-is to
    SyncServer.update_agent_core_memory (presumably None means
    "leave that section unchanged" — verify against the server).
    """

    user_id: str
    agent_id: str
    # Union[...] instead of `str | None`: keeps the module importable on
    # Python < 3.10 and actually uses the `Union` imported at the top of
    # this file (previously unused).
    human: Union[str, None] = None
    persona: Union[str, None] = None
|
# Module-level singletons shared by every route handler below; they are
# populated by the FastAPI `lifespan` context manager at application startup.
server = None  # SyncServer instance (None until startup)
interface = None  # QueuingInterface instance (None until startup)
|
@asynccontextmanager
async def lifespan(application: FastAPI):
    """FastAPI lifespan hook: build the shared MemGPT server + interface
    at startup, persist agent state at (clean) shutdown."""
    global server, interface

    # Startup: wire a queuing interface into a fresh synchronous server.
    interface = QueuingInterface()
    server = SyncServer(default_interface=interface)

    yield  # the application serves requests here

    # Shutdown: flush agent state to disk, then drop the server reference.
    server.save_agents()
    server = None
|
# The ASGI application. Startup/shutdown of the global server + interface is
# handled by the `lifespan` context manager above. (Removed the dead
# commented-out bootstrap code that constructed these eagerly at import time.)
app = FastAPI(lifespan=lifespan)
|
# Thin wrapper over server.list_agents
@app.get("/agents")
def list_agents(user_id: str):
    """Return every agent registered for *user_id*."""
    interface.clear()  # drop any stale queued output first
    agents = server.list_agents(user_id=user_id)
    return agents
@app.get("/agents/memory")
def get_agent_memory(user_id: str, agent_id: str):
    """Return the core memory state of a single agent."""
    interface.clear()  # drop any stale queued output first
    memory = server.get_agent_memory(user_id=user_id, agent_id=agent_id)
    return memory
|
@app.put("/agents/memory")
def put_agent_memory(body: CoreMemory):
    """Replace the agent's core memory sections with the supplied values."""
    interface.clear()  # drop any stale queued output first
    updates = {
        "persona": body.persona,
        "human": body.human,
    }
    return server.update_agent_core_memory(user_id=body.user_id, agent_id=body.agent_id, new_memory_contents=updates)
|
@app.get("/agents/config")
def get_agent_config(user_id: str, agent_id: str):
    """Return the stored configuration of a single agent."""
    interface.clear()  # drop any stale queued output first
    config = server.get_agent_config(user_id=user_id, agent_id=agent_id)
    return config
|
@app.get("/config")
def get_server_config(user_id: str):
    """Return the server-wide MemGPT configuration."""
    interface.clear()  # drop any stale queued output first
    config = server.get_server_config(user_id=user_id)
    return config
|
# Thin wrapper over server.create_agent
@app.post("/agents")
def create_agents(body: CreateAgentConfig):
    """Create a new agent from the posted config and return its id.

    Any server-side failure surfaces as HTTP 500 with the error text.
    """
    interface.clear()  # drop any stale queued output first
    try:
        new_agent_id = server.create_agent(user_id=body.user_id, agent_config=body.config)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
    return {"agent_id": new_agent_id}
|
# server.user_message
@app.post("/agents/message")
async def user_message(body: UserMessage):
    """Deliver a user message to an agent.

    Two modes, selected by `body.stream`:
    - stream=True: kick off message processing in the background and return
      a Server-Sent-Events response that relays messages pulled from the
      queuing interface as they arrive.
    - stream=False: process synchronously and return the accumulated
      interface messages as one JSON payload.

    Raises HTTP 500 on any unexpected server-side failure.
    """
    if body.stream:
        # For streaming response
        # NOTE(review): unlike the non-streaming branch, interface.clear() is
        # NOT called here — confirm whether stale queued messages can leak
        # into the stream.
        try:
            # Start the generation process (similar to the non-streaming case)
            # This should be a non-blocking call or run in a background task

            # Check if server.user_message is an async function
            if asyncio.iscoroutinefunction(server.user_message):
                # Start the async task
                asyncio.create_task(server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message))
            else:
                # Run the synchronous function in a thread pool
                # NOTE(review): the returned future is discarded, so errors
                # raised inside server.user_message are silently dropped here.
                loop = asyncio.get_event_loop()
                loop.run_in_executor(None, server.user_message, body.user_id, body.agent_id, body.message)

            # SSE formatter: wraps each queued message as a `data: ...` event.
            async def formatted_message_generator():
                async for message in interface.message_generator():
                    formatted_message = f"data: {json.dumps(message)}\n\n"
                    yield formatted_message
                    # Throttle the stream; presumably paces client-side rendering.
                    await asyncio.sleep(1)

            # Return the streaming response using the generator
            return StreamingResponse(formatted_message_generator(), media_type="text/event-stream")
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    else:
        # Non-streaming: reset the queue, run synchronously, return everything
        # the interface collected during the call.
        interface.clear()
        try:
            server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return {"messages": interface.to_list()}
|
# Thin wrapper over server.run_command
@app.post("/agents/command")
def run_command(body: Command):
    """Execute a raw MemGPT command against an agent and return its response.

    Raises HTTP 500 (with the error text) on any server-side failure.
    """
    interface.clear()  # drop any stale queued output first
    try:
        response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
    # BUG FIX: the original called server.run_command a SECOND time after the
    # try/except, executing the command twice and bypassing the error handling.
    return {"response": response}