feat: Partial support for OpenAI-compatible assistant API (#838)

This commit is contained in:
Sarah Wooders
2024-02-13 16:09:20 -08:00
committed by GitHub
parent 75fd604d7a
commit bf252b90f0
8 changed files with 944 additions and 22 deletions

View File

@@ -33,7 +33,7 @@ jobs:
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv -k "not test_storage and not test_server" tests
poetry run pytest -s -vv -k "not test_storage and not test_server and not test_openai_client" tests
- name: Run storage tests
env:

View File

@@ -0,0 +1,51 @@
from openai import OpenAI
import time
"""
This script provides an example of how you can use OpenAI's python client with a MemGPT server.
Before running this example, make sure you start the OpenAI-compatible REST server:
cd memgpt/server/rest_api/openai_assistants
poetry run uvicorn assistants:app --reload --port 8080
"""
def main():
    """Drive a MemGPT server end-to-end through the official OpenAI python client."""
    # Point the OpenAI client at the local MemGPT-compatible REST server.
    client = OpenAI(base_url="http://127.0.0.1:8080/v1")

    # Create an assistant (backed by a MemGPT preset).
    assistant = client.beta.assistants.create(
        name="Math Tutor",
        instructions="You are a personal math tutor. Write and run code to answer math questions.",
        model="gpt-4-turbo-preview",
    )

    # Create a thread (backed by a MemGPT agent).
    thread = client.beta.threads.create()

    # Queue a user message on the thread.
    message = client.beta.threads.messages.create(
        thread_id=thread.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?"
    )

    # Execute the agent on the queued messages.
    # NOTE: MemGPT does not support polling yet, so run status is always "completed"
    run = client.beta.threads.runs.create(
        thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe. The user has a premium account."
    )
    run_id = run.id

    # Fetch the full thread history and print each message.
    # NOTE(review): the server responds with a `messages` field rather than the
    # standard OpenAI `data` field — confirm the client parses it as expected.
    history = client.beta.threads.messages.list(thread_id=thread.id)
    for msg in history.messages:
        role = msg["role"]
        content = msg["content"][0]
        print(f"{role.capitalize()}: {content}")


if __name__ == "__main__":
    main()

View File

@@ -4,7 +4,7 @@ import inspect
import json
from pathlib import Path
import traceback
from typing import List, Tuple, Optional, cast
from typing import List, Tuple, Optional, cast, Union
from memgpt.data_types import AgentState, Message, EmbeddingConfig
@@ -546,7 +546,7 @@ class Agent(object):
def step(
self,
user_message: Optional[str], # NOTE: should be json.dump(dict)
user_message: Union[Message, str], # NOTE: should be json.dump(dict)
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
@@ -556,11 +556,18 @@ class Agent(object):
try:
# Step 0: add user message
if user_message is not None:
# Create the user message dict
self.interface.user_message(user_message)
packed_user_message = {"role": "user", "content": user_message}
if isinstance(user_message, Message):
user_message_text = user_message.text
elif isinstance(user_message, str):
user_message_text = user_message
else:
raise ValueError(f"Bad type for user_message: {type(user_message)}")
self.interface.user_message(user_message_text)
packed_user_message = {"role": "user", "content": user_message_text}
# Special handling for AutoGen messages with 'name' field
try:
user_message_json = json.loads(user_message, strict=JSON_LOADS_STRICT)
user_message_json = json.loads(user_message_text, strict=JSON_LOADS_STRICT)
# Special handling for AutoGen messages with 'name' field
# Treat 'name' as a special field
# If it exists in the input message, elevate it to the 'message' level
@@ -629,7 +636,17 @@ class Agent(object):
# Step 4: extend the message history
if user_message is not None:
all_new_messages = [packed_user_message_obj] + all_response_messages
if isinstance(user_message, Message):
all_new_messages = [user_message] + all_response_messages
else:
all_new_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict=packed_user_message,
)
] + all_response_messages
else:
all_new_messages = all_response_messages

157
memgpt/models/openai.py Normal file
View File

@@ -0,0 +1,157 @@
from typing import List, Union, Optional, Dict, Literal
from enum import Enum
from pydantic import BaseModel, Field, Json
import uuid
class ImageFile(BaseModel):
    """Message content part referencing an uploaded image file."""

    type: str = "image_file"
    file_id: str


class Text(BaseModel):
    """Plain-text content part of a message."""

    object: str = "text"
    text: str = Field(..., description="The text content to be processed by the agent.")


class MessageRoleType(str, Enum):
    """Allowed sender roles for incoming messages."""

    user = "user"
    system = "system"
class OpenAIAssistant(BaseModel):
    """Represents an OpenAI assistant (equivalent to MemGPT preset)"""

    id: str = Field(..., description="The unique identifier of the assistant.")
    name: str = Field(..., description="The name of the assistant.")
    object: str = "assistant"
    description: Optional[str] = Field(None, description="The description of the assistant.")
    created_at: int = Field(..., description="The unix timestamp of when the assistant was created.")
    model: str = Field(..., description="The model used by the assistant.")
    instructions: str = Field(..., description="The instructions for the assistant.")
    # NOTE(review): tools are plain names here, while the OpenAI API uses tool
    # objects — confirm this simplification is intentional.
    tools: Optional[List[str]] = Field(None, description="The tools used by the assistant.")
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the assistant.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the assistant.")
class OpenAIMessage(BaseModel):
    """Represents an OpenAI thread message (one entry in a thread's history)."""

    id: str = Field(..., description="The unique identifier of the message.")
    object: str = "thread.message"
    created_at: int = Field(..., description="The unix timestamp of when the message was created.")
    thread_id: str = Field(..., description="The unique identifier of the thread.")
    role: str = Field(..., description="Role of the message sender (either 'user' or 'system')")
    # FIX: annotated Optional to match the None default — the previous
    # annotation claimed a required list while the default allowed absence.
    content: Optional[List[Union[Text, ImageFile]]] = Field(None, description="The message content to be processed by the agent.")
    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    run_id: Optional[str] = Field(None, description="The unique identifier of the run.")
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.")
    metadata: Optional[Dict] = Field(None, description="Metadata associated with the message.")
class MessageFile(BaseModel):
    """File attached to a thread message.

    NOTE(review): this definition is shadowed by a later ``MessageFile``
    class in this module (which adds a ``message_id`` field); the two
    should be merged into one.
    """

    id: str
    object: str = "thread.message.file"
    created_at: int  # unix timestamp
class OpenAIThread(BaseModel):
    """Represents an OpenAI thread (equivalent to MemGPT agent)"""

    id: str = Field(..., description="The unique identifier of the thread.")
    object: str = "thread"
    created_at: int = Field(..., description="The unix timestamp of when the thread was created.")
    # FIX: annotated Optional to match the None default — the previous
    # annotation declared a plain dict while defaulting to None.
    metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.")
class AssistantFile(BaseModel):
    """File attached to an assistant."""

    id: str = Field(..., description="The unique identifier of the file.")
    object: str = "assistant.file"
    created_at: int = Field(..., description="The unix timestamp of when the file was created.")
    assistant_id: str = Field(..., description="The unique identifier of the assistant.")


class MessageFile(BaseModel):
    """File attached to a thread message.

    NOTE(review): this redefines the earlier ``MessageFile`` class in this
    module; being the later definition, it is the one that takes effect.
    """

    id: str = Field(..., description="The unique identifier of the file.")
    object: str = "thread.message.file"
    created_at: int = Field(..., description="The unix timestamp of when the file was created.")
    message_id: str = Field(..., description="The unique identifier of the message.")
class Function(BaseModel):
    """A function invocation requested by the model."""

    name: str = Field(..., description="The name of the function.")
    arguments: str = Field(..., description="The arguments of the function.")


class ToolCall(BaseModel):
    """A single tool (function) call made during a run."""

    id: str = Field(..., description="The unique identifier of the tool call.")
    type: str = "function"
    function: Function = Field(..., description="The function call.")


class ToolCallOutput(BaseModel):
    """The output produced for a previously issued tool call."""

    tool_call_id: str = Field(..., description="The unique identifier of the tool call.")
    output: str = Field(..., description="The output of the tool call.")


class RequiredAction(BaseModel):
    """Action the client must take before a run can proceed."""

    type: str = "submit_tool_outputs"
    submit_tool_outputs: List[ToolCall]


class OpenAIError(BaseModel):
    """Error details attached to a failed run or run step."""

    code: str = Field(..., description="The error code.")
    message: str = Field(..., description="The error message.")


class OpenAIUsage(BaseModel):
    """Token accounting for a run."""

    completion_tokens: int = Field(..., description="The number of tokens used for the run.")
    prompt_tokens: int = Field(..., description="The number of tokens used for the prompt.")
    total_tokens: int = Field(..., description="The total number of tokens used for the run.")


class OpenAIMessageCreationStep(BaseModel):
    """Run-step detail for a step that created a message."""

    type: str = "message_creation"
    message_id: str = Field(..., description="The unique identifier of the message.")


class OpenAIToolCallsStep(BaseModel):
    """Run-step detail for a step that issued tool calls."""

    type: str = "tool_calls"
    tool_calls: List[ToolCall] = Field(..., description="The tool calls.")
class OpenAIRun(BaseModel):
    """Represents an OpenAI run: one execution of an assistant on a thread."""

    id: str = Field(..., description="The unique identifier of the run.")
    object: str = "thread.run"
    created_at: int = Field(..., description="The unix timestamp of when the run was created.")
    thread_id: str = Field(..., description="The unique identifier of the thread.")
    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    status: str = Field(..., description="The status of the run.")
    required_action: Optional[RequiredAction] = Field(None, description="The required action of the run.")
    last_error: Optional[OpenAIError] = Field(None, description="The last error of the run.")
    expires_at: int = Field(..., description="The unix timestamp of when the run expires.")
    started_at: Optional[int] = Field(None, description="The unix timestamp of when the run started.")
    cancelled_at: Optional[int] = Field(None, description="The unix timestamp of when the run was cancelled.")
    failed_at: Optional[int] = Field(None, description="The unix timestamp of when the run failed.")
    completed_at: Optional[int] = Field(None, description="The unix timestamp of when the run completed.")
    model: str = Field(..., description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run.")  # TODO: also add code interpreter / retrieval
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the run.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
    usage: Optional[OpenAIUsage] = Field(None, description="The usage of the run.")
class OpenAIRunStep(BaseModel):
    """Represents one step executed within a run."""

    id: str = Field(..., description="The unique identifier of the run step.")
    object: str = "thread.run.step"
    created_at: int = Field(..., description="The unix timestamp of when the run step was created.")
    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    thread_id: str = Field(..., description="The unique identifier of the thread.")
    run_id: str = Field(..., description="The unique identifier of the run.")
    type: str = Field(..., description="The type of the run step.")  # message_creation, tool_calls
    status: str = Field(..., description="The status of the run step.")
    # NOTE(review): the OpenAI API calls this field `step_details`; confirm
    # whether the name `step_defaults` is intentional before clients rely on it.
    step_defaults: Union[OpenAIToolCallsStep, OpenAIMessageCreationStep] = Field(..., description="The step defaults.")
    last_error: Optional[OpenAIError] = Field(None, description="The last error of the run step.")
    expired_at: Optional[int] = Field(None, description="The unix timestamp of when the run step expired.")
    failed_at: Optional[int] = Field(None, description="The unix timestamp of when the run failed.")
    completed_at: Optional[int] = Field(None, description="The unix timestamp of when the run completed.")
    usage: Optional[OpenAIUsage] = Field(None, description="The usage of the run.")

View File

@@ -0,0 +1,532 @@
import asyncio
from fastapi import FastAPI
from asyncio import AbstractEventLoop
from enum import Enum
import json
import uuid
from typing import List, Optional, Union
from datetime import datetime
from fastapi import APIRouter, Depends, Body, HTTPException, Query, Path
from pydantic import BaseModel, Field, constr, validator
from starlette.responses import StreamingResponse
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.config import MemGPTConfig
import uuid
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.models.openai import (
AssistantFile,
MessageFile,
OpenAIAssistant,
OpenAIThread,
OpenAIMessage,
OpenAIRun,
OpenAIRunStep,
MessageRoleType,
Text,
ImageFile,
ToolCall,
ToolCallOutput,
)
from memgpt.data_types import LLMConfig, EmbeddingConfig, Message
from memgpt.constants import DEFAULT_PRESET
"""
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
Start the server with:
cd memgpt/server/rest_api/openai_assistants
poetry run uvicorn assistants:app --reload --port 8080
"""
# Module-level singletons: every request is served by one shared MemGPT
# server instance fronted by a queuing interface.
interface: QueuingInterface = QueuingInterface()
server: SyncServer = SyncServer(default_interface=interface)
# router = APIRouter()
app = FastAPI()

# NOTE(review): all requests are attributed to the anonymous client ID from
# the local MemGPT config — there is no per-request authentication yet.
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
print(f"User ID: {user_id}")
class CreateAssistantRequest(BaseModel):
    """Request body for creating an assistant (i.e. a MemGPT preset)."""

    model: str = Field(..., description="The model to use for the assistant.")
    name: str = Field(..., description="The name of the assistant.")
    description: str = Field(None, description="The description of the assistant.")
    instructions: str = Field(..., description="The instructions for the assistant.")
    tools: List[str] = Field(None, description="The tools used by the assistant.")
    file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.")
    metadata: dict = Field(None, description="Metadata associated with the assistant.")
    # memgpt-only (not openai)
    embedding_model: str = Field(None, description="The model to use for the assistant.")
    ## TODO: remove
    # user_id: str = Field(..., description="The unique identifier of the user.")


class CreateThreadRequest(BaseModel):
    """Request body for creating a thread (i.e. a MemGPT agent)."""

    messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.")
    # memgpt-only
    assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. MemGPT preset)")


class ModifyThreadRequest(BaseModel):
    """Request body for updating a thread's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the thread.")


class ModifyMessageRequest(BaseModel):
    """Request body for updating a message's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the message.")


class ModifyRunRequest(BaseModel):
    """Request body for updating a run's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the run.")


class CreateMessageRequest(BaseModel):
    """Request body for appending a message to a thread."""

    role: str = Field(..., description="Role of the message sender (either 'user' or 'system')")
    content: str = Field(..., description="The message content to be processed by the agent.")
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the message.")


class UserMessageRequest(BaseModel):
    """Request body for sending a user message directly to an agent."""

    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    message: str = Field(..., description="The message content to be processed by the agent.")
    stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")


class UserMessageResponse(BaseModel):
    """Messages produced by the agent in response to a user message."""

    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")


class GetAgentMessagesRequest(BaseModel):
    """Request body for paging through an agent's message history."""

    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")


class ListMessagesResponse(BaseModel):
    """Response envelope for listing a thread's messages."""

    messages: List[OpenAIMessage] = Field(..., description="List of message objects.")
class CreateAssistantFileRequest(BaseModel):
    """Request body for attaching an uploaded file to an assistant."""

    file_id: str = Field(..., description="The unique identifier of the file.")


class CreateRunRequest(BaseModel):
    """Request body for executing an assistant on a thread's queued messages."""

    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    model: Optional[str] = Field(None, description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")


class CreateThreadRunRequest(BaseModel):
    """Request body for creating a thread and immediately running it."""

    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    thread: OpenAIThread = Field(..., description="The thread to run.")
    model: str = Field(..., description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")


class DeleteAssistantResponse(BaseModel):
    """Deletion acknowledgement for an assistant."""

    id: str = Field(..., description="The unique identifier of the agent.")
    object: str = "assistant.deleted"
    deleted: bool = Field(..., description="Whether the agent was deleted.")


class DeleteAssistantFileResponse(BaseModel):
    """Deletion acknowledgement for a file attached to an assistant."""

    id: str = Field(..., description="The unique identifier of the file.")
    object: str = "assistant.file.deleted"
    deleted: bool = Field(..., description="Whether the file was deleted.")


class DeleteThreadResponse(BaseModel):
    """Deletion acknowledgement for a thread."""

    id: str = Field(..., description="The unique identifier of the agent.")
    object: str = "thread.deleted"
    deleted: bool = Field(..., description="Whether the agent was deleted.")


class SubmitToolOutputsToRunRequest(BaseModel):
    """Request body for submitting tool outputs back to a run.

    NOTE(review): the OpenAI API names this field ``tool_outputs``;
    confirm whether ``tools_outputs`` is intentional.
    """

    tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.")
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
@app.get("/v1/health", tags=["assistant"])
def get_health():
    """Liveness probe; always reports the service as healthy."""
    health = {"status": "healthy"}
    return health
# create assistant (maps onto a MemGPT preset)
@app.post("/v1/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
    """Create an assistant (stub).

    No preset is actually created yet: the request is echoed back with the
    default preset's identifier, and ``request.name`` is ignored.
    """
    # TODO: create preset
    return OpenAIAssistant(
        id=DEFAULT_PRESET,
        name="default_preset",
        description=request.description,
        created_at=int(datetime.now().timestamp()),
        model=request.model,
        instructions=request.instructions,
        tools=request.tools,
        file_ids=request.file_ids,
        metadata=request.metadata,
    )
@app.post("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
def create_assistant_file(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    request: CreateAssistantFileRequest = Body(...),
):
    """Attach a file to an assistant (stub: echoes back an AssistantFile record)."""
    # TODO: add file to assistant
    timestamp = int(datetime.now().timestamp())
    attached = AssistantFile(
        id=request.file_id,
        created_at=timestamp,
        assistant_id=assistant_id,
    )
    return attached
@app.get("/v1/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
def list_assistants(
    limit: int = Query(1000, description="How many assistants to retrieve."),
    order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
    """List assistants (stub; always raises 404 until implemented)."""
    # TODO: implement list assistants (i.e. list available MemGPT presets)
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.get("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
def list_assistant_files(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    limit: int = Query(1000, description="How many files to retrieve."),
    order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
    """List files attached to an assistant (stub; always raises 404)."""
    # TODO: list attached data sources to preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.get("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def retrieve_assistant(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
    """Retrieve an assistant (stub; always raises 404)."""
    # TODO: get and return preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.get("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
def retrieve_assistant_file(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    """Retrieve a file attached to an assistant (stub; always raises 404)."""
    # TODO: return data source attached to preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.post("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def modify_assistant(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    request: CreateAssistantRequest = Body(...),
):
    """Modify an assistant (stub; always raises 404)."""
    # TODO: modify preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.delete("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
def delete_assistant(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
    """Delete an assistant (stub; always raises 404)."""
    # TODO: delete preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.delete("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
def delete_assistant_file(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    """Detach a file from an assistant (stub; always raises 404)."""
    # TODO: delete source on preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads", tags=["assistants"], response_model=OpenAIThread)
def create_thread(request: CreateThreadRequest = Body(...)):
    """Create a thread by spinning up a new MemGPT agent for the shared user.

    NOTE(review): ``request.messages``, ``request.metadata`` and
    ``request.assistant_name`` are currently ignored.
    """
    # TODO: use requests.description and requests.metadata fields
    # TODO: handle requests.file_ids and requests.tools
    # TODO: eventually allow request to override embedding/llm model
    print("Create thread/agent", request)
    # create a memgpt agent
    agent_state = server.create_agent(
        user_id=user_id,
        agent_config={
            "user_id": user_id,
        },
    )
    # TODO: insert messages into recall memory
    return OpenAIThread(
        id=str(agent_state.id),
        created_at=int(agent_state.created_at.timestamp()),
    )
@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
def retrieve_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
):
    """Fetch a thread (i.e. the MemGPT agent backing it) by its UUID."""
    agent = server.get_agent(uuid.UUID(thread_id))
    return OpenAIThread(
        id=str(agent.id),
        created_at=int(agent.created_at.timestamp()),
    )
@app.post("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
def modify_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: ModifyThreadRequest = Body(...),
):
    """Modify a thread's metadata (stub; always raises 404).

    FIX: registered as POST — the OpenAI Assistants API modifies threads via
    POST, and the previous ``@app.get`` registration collided with the
    retrieve-thread handler on the same path.
    """
    # TODO: add agent metadata so this can be modified
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.delete("/v1/threads/{thread_id}", tags=["assistants"], response_model=DeleteThreadResponse)
def delete_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
):
    """Delete a thread (stub; always raises 404 until implemented)."""
    # TODO: delete agent
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=OpenAIMessage)
def create_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: CreateMessageRequest = Body(...),
):
    """Append a message to a thread without running the agent.

    The message is appended to the backing agent via
    ``Agent._append_to_messages``; the agent is only stepped when a run is
    created for the thread.
    """
    agent_id = uuid.UUID(thread_id)
    # create message object
    message = Message(
        user_id=user_id,
        agent_id=agent_id,
        role=request.role,
        text=request.content,
    )
    agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
    # add message to agent
    agent._append_to_messages([message])
    # NOTE(review): request.file_ids and request.metadata are currently dropped.
    openai_message = OpenAIMessage(
        id=str(message.id),
        created_at=int(message.created_at.timestamp()),
        content=[Text(text=message.text)],
        role=message.role,
        thread_id=str(message.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
    return openai_message
@app.get("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=ListMessagesResponse)
def list_messages(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many messages to retrieve."),
    order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List a thread's messages in OpenAI wire format.

    Pagination cursors (`after`/`before`) are message UUIDs; `order` controls
    the chronological direction of the listing.
    """
    # BUG FIX: `after` was previously gated on `before` being set, so the
    # `after` cursor was silently ignored unless `before` was also supplied.
    after_uuid = uuid.UUID(after) if after else None
    before_uuid = uuid.UUID(before) if before else None
    agent_id = uuid.UUID(thread_id)
    reverse = order == "desc"
    cursor, json_messages = server.get_agent_recall_cursor(
        user_id=user_id,
        agent_id=agent_id,
        limit=limit,
        after=after_uuid,
        before=before_uuid,
        order_by="created_at",
        reverse=reverse,
    )
    # Convert MemGPT recall-memory records to OpenAI-style message objects.
    # (Previously the first record's text was printed unconditionally, which
    # raised IndexError for an empty history.)
    openai_messages = [
        OpenAIMessage(
            id=str(message["id"]),
            created_at=int(message["created_at"].timestamp()),
            content=[Text(text=message["text"])],
            role=message["role"],
            thread_id=str(message["agent_id"]),
            assistant_id=DEFAULT_PRESET,  # TODO: update this
            # file_ids=message.file_ids,
            # metadata=message.metadata,
        )
        for message in json_messages
    ]
    # TODO: cast back to message objects
    return ListMessagesResponse(messages=openai_messages)
@app.get("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
def retrieve_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
):
    """Retrieve a single thread message in OpenAI wire format.

    BUG FIX: the route decorator was previously a bare ``app.get(...)`` call
    (missing the ``@``), so this endpoint was never registered with FastAPI.
    """
    message_id = uuid.UUID(message_id)
    agent_id = uuid.UUID(thread_id)
    message = server.get_agent_message(agent_id, message_id)
    return OpenAIMessage(
        id=str(message.id),
        created_at=int(message.created_at.timestamp()),
        content=[Text(text=message.text)],
        role=message.role,
        thread_id=str(message.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
@app.get("/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["assistants"], response_model=MessageFile)
def retrieve_message_file(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    """Retrieve a file attached to a message (stub; always raises 404)."""
    # TODO: implement?
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.post("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
def modify_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    request: ModifyMessageRequest = Body(...),
):
    """Modify a message's metadata (stub; always raises 404)."""
    # TODO: add metadata field to message so this can be modified
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=OpenAIRun)
def create_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: CreateRunRequest = Body(...),
):
    """Execute the agent on the messages already queued on the thread.

    Execution is synchronous, so the returned run is always "completed".
    """
    # TODO: add request.instructions as a message?
    agent_id = uuid.UUID(thread_id)
    # TODO: override preset of agent with request.assistant_id
    agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
    agent.step(user_message=None)  # already has messages added
    run_id = str(uuid.uuid4())
    create_time = int(datetime.now().timestamp())
    # NOTE(review): expires_at is set to the creation time, i.e. the run
    # "expires" immediately — confirm whether a future timestamp is intended.
    return OpenAIRun(
        id=run_id,
        created_at=create_time,
        thread_id=str(agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        status="completed",  # TODO: eventually allow offline execution
        expires_at=create_time,
        model=agent.agent_state.llm_config.model,
        instructions=request.instructions,
    )
@app.post("/v1/threads/runs", tags=["assistants"], response_model=OpenAIRun)
def create_thread_and_run(
    request: CreateThreadRunRequest = Body(...),
):
    """Create a thread and immediately run it (stub; always raises 404)."""
    # TODO: add a bunch of messages and execute
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.get("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=List[OpenAIRun])
def list_runs(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many runs to retrieve."),
    order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
    """List a thread's runs (stub; always raises 404)."""
    # TODO: store run information in a DB so it can be returned here
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps", tags=["assistants"], response_model=List[OpenAIRunStep])
def list_run_steps(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    limit: int = Query(1000, description="How many run steps to retrieve."),
    order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
    """List the steps of a run (stub; always raises 404)."""
    # TODO: store run information in a DB so it can be returned here
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.get("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
def retrieve_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
):
    """Retrieve a run (stub; always raises 404)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["assistants"], response_model=OpenAIRunStep)
def retrieve_run_step(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    step_id: str = Path(..., description="The unique identifier of the run step."),
):
    """Retrieve a single run step (stub; always raises 404)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.post("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
def modify_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    request: ModifyRunRequest = Body(...),
):
    """Modify a run's metadata (stub; always raises 404)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.post("/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["assistants"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    request: SubmitToolOutputsToRunRequest = Body(...),
):
    """Submit tool outputs back to a run (stub; always raises 404)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")


@app.post("/v1/threads/{thread_id}/runs/{run_id}/cancel", tags=["assistants"], response_model=OpenAIRun)
def cancel_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
):
    """Cancel a run (stub; always raises 404)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")

View File

@@ -25,6 +25,12 @@ from memgpt.data_types import (
AgentState,
LLMConfig,
EmbeddingConfig,
Message,
ToolCall,
LLMConfig,
EmbeddingConfig,
Message,
ToolCall,
)
@@ -292,7 +298,7 @@ class SyncServer(LockingServer):
memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
return memgpt_agent
def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: str) -> int:
def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: Union[str, Message]) -> int:
"""Send the input message through the agent"""
logger.debug(f"Got input message: {input_message}")
@@ -451,7 +457,7 @@ class SyncServer(LockingServer):
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
@LockingServer.agent_lock_decorator
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message]) -> None:
"""Process an incoming user message and feed it through the MemGPT agent"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -459,21 +465,27 @@ class SyncServer(LockingServer):
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# Basic input sanitization
if not isinstance(message, str) or len(message) == 0:
raise ValueError(f"Invalid input: '{message}'")
if isinstance(message, str):
if len(message) == 0:
raise ValueError(f"Invalid input: '{message}'")
# If the input begins with a command prefix, reject
elif message.startswith("/"):
raise ValueError(f"Invalid input: '{message}'")
# Else, process it as a user message to be fed to the agent
else:
# Package the user message first
# If the input begins with a command prefix, reject
elif message.startswith("/"):
raise ValueError(f"Invalid input: '{message}'")
packaged_user_message = system.package_user_message(user_message=message)
# Run the agent state forward
tokens_accumulated = self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
elif isinstance(message, Message):
if len(message.text) == 0:
raise ValueError(f"Invalid input: '{message.text}'")
return tokens_accumulated
# If the input begins with a command prefix, reject
elif message.text.startswith("/"):
raise ValueError(f"Invalid input: '{message.text}'")
packaged_user_message = message
else:
raise ValueError(f"Invalid input: '{message}'")
# Run the agent state forward
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
@LockingServer.agent_lock_decorator
def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
@@ -665,6 +677,19 @@ class SyncServer(LockingServer):
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
return [m.id for m in memgpt_agent._messages]
def get_agent_message(self, agent_id: uuid.UUID, message_id: uuid.UUID) -> Message:
    """Get message based on agent and message ID"""
    # Look up the agent's metadata row; the owning user is needed to load the agent.
    agent_state = self.ms.get_agent(agent_id=agent_id)
    if agent_state is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")
    user_id = agent_state.user_id
    # Get the agent object (loaded in memory)
    memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
    # Fetch the message from the agent's recall-memory storage backend.
    # NOTE(review): presumably storage.get returns None for an unknown message_id — confirm callers handle that.
    message = memgpt_agent.persistence_manager.recall_memory.storage.get(message_id=message_id)
    return message
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of all messages in agent message queue"""
if self.ms.get_user(user_id=user_id) is None:
@@ -757,6 +782,7 @@ class SyncServer(LockingServer):
before: Optional[uuid.UUID] = None,
limit: Optional[int] = 100,
order_by: Optional[str] = "created_at",
order: Optional[str] = "asc",
reverse: Optional[bool] = False,
):
if self.ms.get_user(user_id=user_id) is None:

View File

@@ -0,0 +1,50 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
import uuid
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.openai_assistants.assistants import app
from memgpt.constants import DEFAULT_PRESET
def test_list_messages():
    """Exercise thread creation, message insertion, and message listing
    through the OpenAI-compatible assistants REST API."""
    client = TestClient(app)
    user_id = uuid.uuid4()

    # Register the user that will own the thread/agent.
    server = SyncServer()
    server.create_user({"id": user_id})

    # Create a thread (backed by a MemGPT agent) from the default preset.
    create_payload = {
        "user_id": str(user_id),
        "assistant_name": DEFAULT_PRESET,
    }
    print(create_payload)
    response = client.post("/v1/threads", json=create_payload)
    assert response.status_code == 200, f"Error: {response.json()}"
    agent_id = response.json()["id"]
    print(response.json())

    # Append a user message to the thread.
    # TODO: eventually implement the "run" functionality
    message_payload = {
        "user_id": str(user_id),
        "content": "Hello, world!",
        "role": "user",
    }
    response = client.post(f"/v1/threads/{str(agent_id)}/messages", json=message_payload)
    assert response.status_code == 200, f"Error: {response.json()}"

    # List the messages back out of the same thread.
    thread_id = str(agent_id)
    query_params = {
        "limit": 10,
        "order": "desc",
        # "after": "",
        "user_id": str(user_id),
    }
    response = client.get(f"/v1/threads/{thread_id}/messages", params=query_params)
    assert response.status_code == 200, f"Error: {response.json()}"
    print(response.json())

View File

@@ -0,0 +1,89 @@
from openai import OpenAI
import time
import uvicorn
def test_openai_assistant():
    """Smoke-test the MemGPT OpenAI-compatible assistants server using the
    official OpenAI python client.

    Requires the REST server to already be running at 127.0.0.1:8080
    (see examples/docs for the uvicorn launch command).
    """
    client = OpenAI(base_url="http://127.0.0.1:8080/v1")

    # create assistant
    assistant = client.beta.assistants.create(
        name="Math Tutor",
        instructions="You are a personal math tutor. Write and run code to answer math questions.",
        # tools=[{"type": "code_interpreter"}],
        model="gpt-4-turbo-preview",
    )

    # create thread
    thread = client.beta.threads.create()

    # append a user message to the thread
    message = client.beta.threads.messages.create(
        thread_id=thread.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?"
    )

    # execute the agent on the thread's messages
    run = client.beta.threads.runs.create(
        thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe. The user has a premium account."
    )
    # run = client.beta.threads.runs.create(
    #     thread_id=thread.id,
    #     assistant_id=assistant.id,
    #     model="gpt-4-turbo-preview",
    #     instructions="New instructions that override the Assistant instructions",
    #     tools=[{"type": "code_interpreter"}, {"type": "retrieval"}]
    # )

    # Store the run ID
    run_id = run.id
    print(run_id)

    # NOTE: MemGPT does not support polling yet, so run status is always "completed"

    # Retrieve all messages from the thread
    messages = client.beta.threads.messages.list(thread_id=thread.id)

    # Print all messages from the thread
    # NOTE(review): the stock OpenAI client exposes pages as `.data` with attribute
    # access; `.messages` and dict-style access presumably match MemGPT's custom
    # response shape — confirm against the server's ListMessagesResponse.
    for msg in messages.messages:
        role = msg["role"]
        content = msg["content"][0]
        print(f"{role.capitalize()}: {content}")

    # TODO: add once polling works
    ## Polling for the run status
    # while True:
    #     # Retrieve the run status
    #     run_status = client.beta.threads.runs.retrieve(
    #         thread_id=thread.id,
    #         run_id=run_id
    #     )
    #     # Check and print the step details
    #     run_steps = client.beta.threads.runs.steps.list(
    #         thread_id=thread.id,
    #         run_id=run_id
    #     )
    #     for step in run_steps.data:
    #         if step.type == 'tool_calls':
    #             print(f"Tool {step.type} invoked.")
    #         # If step involves code execution, print the code
    #         if step.type == 'code_interpreter':
    #             print(f"Python Code Executed: {step.step_details['code_interpreter']['input']}")
    #     if run_status.status == 'completed':
    #         # Retrieve all messages from the thread
    #         messages = client.beta.threads.messages.list(
    #             thread_id=thread.id
    #         )
    #         # Print all messages from the thread
    #         for msg in messages.data:
    #             role = msg.role
    #             content = msg.content[0].text.value
    #             print(f"{role.capitalize()}: {content}")
    #         break  # Exit the polling loop since the run is complete
    #     elif run_status.status in ['queued', 'in_progress']:
    #         print(f'{run_status.status.capitalize()}... Please wait.')
    #         time.sleep(1.5)  # Wait before checking again
    #     else:
    #         print(f"Run status: {run_status.status}")
    #         break  # Exit the polling loop if the status is neither 'in_progress' nor 'completed'