From bf252b90f04ecdbf1a017c0c72417aedda310486 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 13 Feb 2024 16:09:20 -0800 Subject: [PATCH] feat: Partial support for OpenAI-compatible assistant API (#838) --- .github/workflows/tests.yml | 2 +- examples/openai_client_assistants.py | 51 ++ memgpt/agent.py | 31 +- memgpt/models/openai.py | 157 ++++++ .../rest_api/openai_assistants/assistants.py | 532 ++++++++++++++++++ memgpt/server/server.py | 54 +- tests/test_openai_assistant_api.py | 50 ++ tests/test_openai_client.py | 89 +++ 8 files changed, 944 insertions(+), 22 deletions(-) create mode 100644 examples/openai_client_assistants.py create mode 100644 memgpt/models/openai.py create mode 100644 memgpt/server/rest_api/openai_assistants/assistants.py create mode 100644 tests/test_openai_assistant_api.py create mode 100644 tests/test_openai_client.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 391c9918..0dbf495d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,7 @@ jobs: PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | - poetry run pytest -s -vv -k "not test_storage and not test_server" tests + poetry run pytest -s -vv -k "not test_storage and not test_server and not test_openai_client" tests - name: Run storage tests env: diff --git a/examples/openai_client_assistants.py b/examples/openai_client_assistants.py new file mode 100644 index 00000000..e6c82ee0 --- /dev/null +++ b/examples/openai_client_assistants.py @@ -0,0 +1,51 @@ +from openai import OpenAI +import time + +""" +This script provides an example of how you can use OpenAI's python client with a MemGPT server. 
+ +Before running this example, make sure you start the OpenAI-compatible REST server: + cd memgpt/server/rest_api/openai_assistants + poetry run uvicorn assistants:app --reload --port 8080 +""" + + +def main(): + client = OpenAI(base_url="http://127.0.0.1:8080/v1") + + # create assistant (creates a memgpt preset) + assistant = client.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + model="gpt-4-turbo-preview", + ) + + # create thread (creates a memgpt agent) + thread = client.beta.threads.create() + + # create a message (appends a message to the memgpt agent) + message = client.beta.threads.messages.create( + thread_id=thread.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?" + ) + + # create a run (executes the agent on the messages) + # NOTE: MemGPT does not support polling yet, so run status is always "completed" + run = client.beta.threads.runs.create( + thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe. The user has a premium account." 
+ ) + + # Store the run ID + run_id = run.id + + # Retrieve all messages from the thread + messages = client.beta.threads.messages.list(thread_id=thread.id) + + # Print all messages from the thread + for msg in messages.messages: + role = msg["role"] + content = msg["content"][0] + print(f"{role.capitalize()}: {content}") + + +if __name__ == "__main__": + main() diff --git a/memgpt/agent.py b/memgpt/agent.py index c3cff1ff..d7bbab94 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -4,7 +4,7 @@ import inspect import json from pathlib import Path import traceback -from typing import List, Tuple, Optional, cast +from typing import List, Tuple, Optional, cast, Union from memgpt.data_types import AgentState, Message, EmbeddingConfig @@ -546,7 +546,7 @@ class Agent(object): def step( self, - user_message: Optional[str], # NOTE: should be json.dump(dict) + user_message: Union[Message, str], # NOTE: should be json.dump(dict) first_message: bool = False, first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, skip_verify: bool = False, @@ -556,11 +556,18 @@ class Agent(object): try: # Step 0: add user message if user_message is not None: - # Create the user message dict - self.interface.user_message(user_message) - packed_user_message = {"role": "user", "content": user_message} + if isinstance(user_message, Message): + user_message_text = user_message.text + elif isinstance(user_message, str): + user_message_text = user_message + else: + raise ValueError(f"Bad type for user_message: {type(user_message)}") + + self.interface.user_message(user_message_text) + packed_user_message = {"role": "user", "content": user_message_text} + # Special handling for AutoGen messages with 'name' field try: - user_message_json = json.loads(user_message, strict=JSON_LOADS_STRICT) + user_message_json = json.loads(user_message_text, strict=JSON_LOADS_STRICT) # Special handling for AutoGen messages with 'name' field # Treat 'name' as a special field # If it exists in the input message, 
elevate it to the 'message' level @@ -629,7 +636,17 @@ class Agent(object): # Step 4: extend the message history if user_message is not None: - all_new_messages = [packed_user_message_obj] + all_response_messages + if isinstance(user_message, Message): + all_new_messages = [user_message] + all_response_messages + else: + all_new_messages = [ + Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.user_id, + model=self.model, + openai_message_dict=packed_user_message, + ) + ] + all_response_messages else: all_new_messages = all_response_messages diff --git a/memgpt/models/openai.py b/memgpt/models/openai.py new file mode 100644 index 00000000..60d59971 --- /dev/null +++ b/memgpt/models/openai.py @@ -0,0 +1,157 @@ +from typing import List, Union, Optional, Dict, Literal +from enum import Enum +from pydantic import BaseModel, Field, Json +import uuid + + +class ImageFile(BaseModel): + type: str = "image_file" + file_id: str + + +class Text(BaseModel): + object: str = "text" + text: str = Field(..., description="The text content to be processed by the agent.") + + +class MessageRoleType(str, Enum): + user = "user" + system = "system" + + +class OpenAIAssistant(BaseModel): + """Represents an OpenAI assistant (equivalent to MemGPT preset)""" + + id: str = Field(..., description="The unique identifier of the assistant.") + name: str = Field(..., description="The name of the assistant.") + object: str = "assistant" + description: Optional[str] = Field(None, description="The description of the assistant.") + created_at: int = Field(..., description="The unix timestamp of when the assistant was created.") + model: str = Field(..., description="The model used by the assistant.") + instructions: str = Field(..., description="The instructions for the assistant.") + tools: Optional[List[str]] = Field(None, description="The tools used by the assistant.") + file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the 
assistant.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the assistant.") + + +class OpenAIMessage(BaseModel): + id: str = Field(..., description="The unique identifier of the message.") + object: str = "thread.message" + created_at: int = Field(..., description="The unix timestamp of when the message was created.") + thread_id: str = Field(..., description="The unique identifier of the thread.") + role: str = Field(..., description="Role of the message sender (either 'user' or 'system')") + content: List[Union[Text, ImageFile]] = Field(None, description="The message content to be processed by the agent.") + assistant_id: str = Field(..., description="The unique identifier of the assistant.") + run_id: Optional[str] = Field(None, description="The unique identifier of the run.") + file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.") + metadata: Optional[Dict] = Field(None, description="Metadata associated with the message.") + + +class MessageFile(BaseModel): + id: str + object: str = "thread.message.file" + created_at: int # unix timestamp + + +class OpenAIThread(BaseModel): + """Represents an OpenAI thread (equivalent to MemGPT agent)""" + + id: str = Field(..., description="The unique identifier of the thread.") + object: str = "thread" + created_at: int = Field(..., description="The unix timestamp of when the thread was created.") + metadata: dict = Field(None, description="Metadata associated with the thread.") + + +class AssistantFile(BaseModel): + id: str = Field(..., description="The unique identifier of the file.") + object: str = "assistant.file" + created_at: int = Field(..., description="The unix timestamp of when the file was created.") + assistant_id: str = Field(..., description="The unique identifier of the assistant.") + + +class MessageFile(BaseModel): + id: str = Field(..., description="The unique identifier of the file.") + object: str = "thread.message.file" + 
created_at: int = Field(..., description="The unix timestamp of when the file was created.") + message_id: str = Field(..., description="The unique identifier of the message.") + + +class Function(BaseModel): + name: str = Field(..., description="The name of the function.") + arguments: str = Field(..., description="The arguments of the function.") + + +class ToolCall(BaseModel): + id: str = Field(..., description="The unique identifier of the tool call.") + type: str = "function" + function: Function = Field(..., description="The function call.") + + +class ToolCallOutput(BaseModel): + tool_call_id: str = Field(..., description="The unique identifier of the tool call.") + output: str = Field(..., description="The output of the tool call.") + + +class RequiredAction(BaseModel): + type: str = "submit_tool_outputs" + submit_tool_outputs: List[ToolCall] + + +class OpenAIError(BaseModel): + code: str = Field(..., description="The error code.") + message: str = Field(..., description="The error message.") + + +class OpenAIUsage(BaseModel): + completion_tokens: int = Field(..., description="The number of tokens used for the run.") + prompt_tokens: int = Field(..., description="The number of tokens used for the prompt.") + total_tokens: int = Field(..., description="The total number of tokens used for the run.") + + +class OpenAIMessageCreationStep(BaseModel): + type: str = "message_creation" + message_id: str = Field(..., description="The unique identifier of the message.") + + +class OpenAIToolCallsStep(BaseModel): + type: str = "tool_calls" + tool_calls: List[ToolCall] = Field(..., description="The tool calls.") + + +class OpenAIRun(BaseModel): + id: str = Field(..., description="The unique identifier of the run.") + object: str = "thread.run" + created_at: int = Field(..., description="The unix timestamp of when the run was created.") + thread_id: str = Field(..., description="The unique identifier of the thread.") + assistant_id: str = Field(..., description="The 
unique identifier of the assistant.") + status: str = Field(..., description="The status of the run.") + required_action: Optional[RequiredAction] = Field(None, description="The required action of the run.") + last_error: Optional[OpenAIError] = Field(None, description="The last error of the run.") + expires_at: int = Field(..., description="The unix timestamp of when the run expires.") + started_at: Optional[int] = Field(None, description="The unix timestamp of when the run started.") + cancelled_at: Optional[int] = Field(None, description="The unix timestamp of when the run was cancelled.") + failed_at: Optional[int] = Field(None, description="The unix timestamp of when the run failed.") + completed_at: Optional[int] = Field(None, description="The unix timestamp of when the run completed.") + model: str = Field(..., description="The model used by the run.") + instructions: str = Field(..., description="The instructions for the run.") + tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run.") # TODO: also add code interpreter / retrieval + file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the run.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the run.") + usage: Optional[OpenAIUsage] = Field(None, description="The usage of the run.") + + +class OpenAIRunStep(BaseModel): + id: str = Field(..., description="The unique identifier of the run step.") + object: str = "thread.run.step" + created_at: int = Field(..., description="The unix timestamp of when the run step was created.") + assistant_id: str = Field(..., description="The unique identifier of the assistant.") + thread_id: str = Field(..., description="The unique identifier of the thread.") + run_id: str = Field(..., description="The unique identifier of the run.") + type: str = Field(..., description="The type of the run step.") # message_creation, tool_calls + status: str = Field(..., description="The 
status of the run step.") + step_defaults: Union[OpenAIToolCallsStep, OpenAIMessageCreationStep] = Field(..., description="The step defaults.") + last_error: Optional[OpenAIError] = Field(None, description="The last error of the run step.") + expired_at: Optional[int] = Field(None, description="The unix timestamp of when the run step expired.") + failed_at: Optional[int] = Field(None, description="The unix timestamp of when the run failed.") + completed_at: Optional[int] = Field(None, description="The unix timestamp of when the run completed.") + usage: Optional[OpenAIUsage] = Field(None, description="The usage of the run.") diff --git a/memgpt/server/rest_api/openai_assistants/assistants.py b/memgpt/server/rest_api/openai_assistants/assistants.py new file mode 100644 index 00000000..2fb3a20f --- /dev/null +++ b/memgpt/server/rest_api/openai_assistants/assistants.py @@ -0,0 +1,532 @@ +import asyncio +from fastapi import FastAPI +from asyncio import AbstractEventLoop +from enum import Enum +import json +import uuid +from typing import List, Optional, Union +from datetime import datetime + +from fastapi import APIRouter, Depends, Body, HTTPException, Query, Path +from pydantic import BaseModel, Field, constr, validator +from starlette.responses import StreamingResponse + +from memgpt.server.rest_api.interface import QueuingInterface +from memgpt.server.server import SyncServer + +from memgpt.config import MemGPTConfig +import uuid + +from memgpt.server.server import SyncServer +from memgpt.server.rest_api.interface import QueuingInterface +from memgpt.server.rest_api.static_files import mount_static_files +from memgpt.models.openai import ( + AssistantFile, + MessageFile, + OpenAIAssistant, + OpenAIThread, + OpenAIMessage, + OpenAIRun, + OpenAIRunStep, + MessageRoleType, + Text, + ImageFile, + ToolCall, + ToolCallOutput, +) +from memgpt.data_types import LLMConfig, EmbeddingConfig, Message +from memgpt.constants import DEFAULT_PRESET + +""" +Basic REST API sitting on 
top of the internal MemGPT python server (SyncServer) + +Start the server with: + cd memgpt/server/rest_api/openai_assistants + poetry run uvicorn assistants:app --reload --port 8080 +""" + +interface: QueuingInterface = QueuingInterface() +server: SyncServer = SyncServer(default_interface=interface) + + +# router = APIRouter() +app = FastAPI() + +user_id = uuid.UUID(MemGPTConfig.load().anon_clientid) +print(f"User ID: {user_id}") + + +class CreateAssistantRequest(BaseModel): + model: str = Field(..., description="The model to use for the assistant.") + name: str = Field(..., description="The name of the assistant.") + description: str = Field(None, description="The description of the assistant.") + instructions: str = Field(..., description="The instructions for the assistant.") + tools: List[str] = Field(None, description="The tools used by the assistant.") + file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.") + metadata: dict = Field(None, description="Metadata associated with the assistant.") + + # memgpt-only (not openai) + embedding_model: str = Field(None, description="The model to use for the assistant.") + + ## TODO: remove + # user_id: str = Field(..., description="The unique identifier of the user.") + + +class CreateThreadRequest(BaseModel): + messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.") + + # memgpt-only + assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. 
MemGPT preset)") + + +class ModifyThreadRequest(BaseModel): + metadata: dict = Field(None, description="Metadata associated with the thread.") + + +class ModifyMessageRequest(BaseModel): + metadata: dict = Field(None, description="Metadata associated with the message.") + + +class ModifyRunRequest(BaseModel): + metadata: dict = Field(None, description="Metadata associated with the run.") + + +class CreateMessageRequest(BaseModel): + role: str = Field(..., description="Role of the message sender (either 'user' or 'system')") + content: str = Field(..., description="The message content to be processed by the agent.") + file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the message.") + + +class UserMessageRequest(BaseModel): + user_id: str = Field(..., description="The unique identifier of the user.") + agent_id: str = Field(..., description="The unique identifier of the agent.") + message: str = Field(..., description="The message content to be processed by the agent.") + stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. 
Set to True for streaming.") + role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')") + + +class UserMessageResponse(BaseModel): + messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.") + + +class GetAgentMessagesRequest(BaseModel): + user_id: str = Field(..., description="The unique identifier of the user.") + agent_id: str = Field(..., description="The unique identifier of the agent.") + start: int = Field(..., description="Message index to start on (reverse chronological).") + count: int = Field(..., description="How many messages to retrieve.") + + +class ListMessagesResponse(BaseModel): + messages: List[OpenAIMessage] = Field(..., description="List of message objects.") + + +class CreateAssistantFileRequest(BaseModel): + file_id: str = Field(..., description="The unique identifier of the file.") + + +class CreateRunRequest(BaseModel): + assistant_id: str = Field(..., description="The unique identifier of the assistant.") + model: Optional[str] = Field(None, description="The model used by the run.") + instructions: str = Field(..., description="The instructions for the run.") + additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.") + tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).") + metadata: Optional[dict] = Field(None, description="Metadata associated with the run.") + + +class CreateThreadRunRequest(BaseModel): + assistant_id: str = Field(..., description="The unique identifier of the assistant.") + thread: OpenAIThread = Field(..., description="The thread to run.") + model: str = Field(..., description="The model used by the run.") + instructions: str = Field(..., description="The instructions for the run.") + tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides 
assistant).") + metadata: Optional[dict] = Field(None, description="Metadata associated with the run.") + + +class DeleteAssistantResponse(BaseModel): + id: str = Field(..., description="The unique identifier of the agent.") + object: str = "assistant.deleted" + deleted: bool = Field(..., description="Whether the agent was deleted.") + + +class DeleteAssistantFileResponse(BaseModel): + id: str = Field(..., description="The unique identifier of the file.") + object: str = "assistant.file.deleted" + deleted: bool = Field(..., description="Whether the file was deleted.") + + +class DeleteThreadResponse(BaseModel): + id: str = Field(..., description="The unique identifier of the agent.") + object: str = "thread.deleted" + deleted: bool = Field(..., description="Whether the agent was deleted.") + + +class SubmitToolOutputsToRunRequest(BaseModel): + tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.") + + +# TODO: implement mechanism for creating/authenticating users associated with a bearer token + + +@app.get("/v1/health", tags=["assistant"]) +def get_health(): + return {"status": "healthy"} + + +# create assistant (MemGPT agent) +@app.post("/v1/assistants", tags=["assistants"], response_model=OpenAIAssistant) +def create_assistant(request: CreateAssistantRequest = Body(...)): + # TODO: create preset + return OpenAIAssistant( + id=DEFAULT_PRESET, + name="default_preset", + description=request.description, + created_at=int(datetime.now().timestamp()), + model=request.model, + instructions=request.instructions, + tools=request.tools, + file_ids=request.file_ids, + metadata=request.metadata, + ) + + +@app.post("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile) +def create_assistant_file( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + request: CreateAssistantFileRequest = Body(...), +): + # TODO: add file to assistant + return AssistantFile( + 
id=request.file_id, + created_at=int(datetime.now().timestamp()), + assistant_id=assistant_id, + ) + + +@app.get("/v1/assistants", tags=["assistants"], response_model=List[OpenAIAssistant]) +def list_assistants( + limit: int = Query(1000, description="How many assistants to retrieve."), + order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), +): + # TODO: implement list assistants (i.e. list available MemGPT presets) + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.get("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile]) +def list_assistant_files( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + limit: int = Query(1000, description="How many files to retrieve."), + order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. 
`after` is an object ID that defines your place in the list."), +): + # TODO: list attached data sources to preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.get("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant) +def retrieve_assistant( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), +): + # TODO: get and return preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.get("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile) +def retrieve_assistant_file( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + file_id: str = Path(..., description="The unique identifier of the file."), +): + # TODO: return data source attached to preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.post("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant) +def modify_assistant( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + request: CreateAssistantRequest = Body(...), +): + # TODO: modify preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.delete("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse) +def delete_assistant( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), +): + # TODO: delete preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.delete("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse) +def delete_assistant_file( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + file_id: str = Path(..., description="The unique identifier of the 
file."), +): + # TODO: delete source on preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.post("/v1/threads", tags=["assistants"], response_model=OpenAIThread) +def create_thread(request: CreateThreadRequest = Body(...)): + # TODO: use requests.description and requests.metadata fields + # TODO: handle requests.file_ids and requests.tools + # TODO: eventually allow request to override embedding/llm model + + print("Create thread/agent", request) + # create a memgpt agent + agent_state = server.create_agent( + user_id=user_id, + agent_config={ + "user_id": user_id, + }, + ) + # TODO: insert messages into recall memory + return OpenAIThread( + id=str(agent_state.id), + created_at=int(agent_state.created_at.timestamp()), + ) + + +@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread) +def retrieve_thread( + thread_id: str = Path(..., description="The unique identifier of the thread."), +): + agent = server.get_agent(uuid.UUID(thread_id)) + return OpenAIThread( + id=str(agent.id), + created_at=int(agent.created_at.timestamp()), + ) + + +@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread) +def modify_thread( + thread_id: str = Path(..., description="The unique identifier of the thread."), + request: ModifyThreadRequest = Body(...), +): + # TODO: add agent metadata so this can be modified + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.delete("/v1/threads/{thread_id}", tags=["assistants"], response_model=DeleteThreadResponse) +def delete_thread( + thread_id: str = Path(..., description="The unique identifier of the thread."), +): + # TODO: delete agent + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.post("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=OpenAIMessage) +def create_message( + thread_id: str = Path(..., description="The unique 
identifier of the thread."), + request: CreateMessageRequest = Body(...), +): + agent_id = uuid.UUID(thread_id) + # create message object + message = Message( + user_id=user_id, + agent_id=agent_id, + role=request.role, + text=request.content, + ) + agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id) + # add message to agent + agent._append_to_messages([message]) + + openai_message = OpenAIMessage( + id=str(message.id), + created_at=int(message.created_at.timestamp()), + content=[Text(text=message.text)], + role=message.role, + thread_id=str(message.agent_id), + assistant_id=DEFAULT_PRESET, # TODO: update this + # file_ids=message.file_ids, + # metadata=message.metadata, + ) + return openai_message + + +@app.get("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=ListMessagesResponse) +def list_messages( + thread_id: str = Path(..., description="The unique identifier of the thread."), + limit: int = Query(1000, description="How many messages to retrieve."), + order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. 
`after` is an object ID that defines your place in the list."), +): + after_uuid = uuid.UUID(after) if before else None + before_uuid = uuid.UUID(before) if before else None + agent_id = uuid.UUID(thread_id) + reverse = True if (order == "desc") else False + cursor, json_messages = server.get_agent_recall_cursor( + user_id=user_id, + agent_id=agent_id, + limit=limit, + after=after_uuid, + before=before_uuid, + order_by="created_at", + reverse=reverse, + ) + print(json_messages[0]["text"]) + # convert to openai style messages + openai_messages = [ + OpenAIMessage( + id=str(message["id"]), + created_at=int(message["created_at"].timestamp()), + content=[Text(text=message["text"])], + role=message["role"], + thread_id=str(message["agent_id"]), + assistant_id=DEFAULT_PRESET # TODO: update this + # file_ids=message.file_ids, + # metadata=message.metadata, + ) + for message in json_messages + ] + print("MESSAGES", openai_messages) + # TODO: cast back to message objects + return ListMessagesResponse(messages=openai_messages) + + +app.get("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage) + + +def retrieve_message( + thread_id: str = Path(..., description="The unique identifier of the thread."), + message_id: str = Path(..., description="The unique identifier of the message."), +): + message_id = uuid.UUID(message_id) + agent_id = uuid.UUID(thread_id) + message = server.get_agent_message(agent_id, message_id) + return OpenAIMessage( + id=str(message.id), + created_at=int(message.created_at.timestamp()), + content=[Text(text=message.text)], + role=message.role, + thread_id=str(message.agent_id), + assistant_id=DEFAULT_PRESET, # TODO: update this + # file_ids=message.file_ids, + # metadata=message.metadata, + ) + + +@app.get("/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["assistants"], response_model=MessageFile) +def retrieve_message_file( + thread_id: str = Path(..., description="The unique identifier 
of the thread."), + message_id: str = Path(..., description="The unique identifier of the message."), + file_id: str = Path(..., description="The unique identifier of the file."), +): + # TODO: implement? + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.post("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage) +def modify_message( + thread_id: str = Path(..., description="The unique identifier of the thread."), + message_id: str = Path(..., description="The unique identifier of the message."), + request: ModifyMessageRequest = Body(...), +): + # TODO: add metada field to message so this can be modified + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.post("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=OpenAIRun) +def create_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + request: CreateRunRequest = Body(...), +): + # TODO: add request.instructions as a message? 
+ agent_id = uuid.UUID(thread_id) + # TODO: override preset of agent with request.assistant_id + agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id) + agent.step(user_message=None) # already has messages added + run_id = str(uuid.uuid4()) + create_time = int(datetime.now().timestamp()) + return OpenAIRun( + id=run_id, + created_at=create_time, + thread_id=str(agent_id), + assistant_id=DEFAULT_PRESET, # TODO: update this + status="completed", # TODO: eventaully allow offline execution + expires_at=create_time, + model=agent.agent_state.llm_config.model, + instructions=request.instructions, + ) + + +@app.post("/v1/threads/runs", tags=["assistants"], response_model=OpenAIRun) +def create_thread_and_run( + request: CreateThreadRunRequest = Body(...), +): + # TODO: add a bunch of messages and execute + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.get("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=List[OpenAIRun]) +def list_runs( + thread_id: str = Path(..., description="The unique identifier of the thread."), + limit: int = Query(1000, description="How many runs to retrieve."), + order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. 
`before` is an object ID that defines your place in the list."),
+):
+ # TODO: store run information in a DB so it can be returned here
+ raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
+
+
+@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps", tags=["assistants"], response_model=List[OpenAIRunStep])
+def list_run_steps(
+ thread_id: str = Path(..., description="The unique identifier of the thread."),
+ run_id: str = Path(..., description="The unique identifier of the run."),
+ limit: int = Query(1000, description="How many run steps to retrieve."),
+ order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
+ after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
+ before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
+):
+ # TODO: store run information in a DB so it can be returned here
+ raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
+
+
+@app.get("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
+def retrieve_run(
+ thread_id: str = Path(..., description="The unique identifier of the thread."),
+ run_id: str = Path(..., description="The unique identifier of the run."),
+):
+ raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
+
+
+@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["assistants"], response_model=OpenAIRunStep)
+def retrieve_run_step(
+ thread_id: str = Path(..., description="The unique identifier of the thread."),
+ run_id: str = Path(..., description="The unique identifier of the run."),
+ step_id: str = Path(..., description="The unique identifier of the run step."),
+):
+ raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
+
+ 
+@app.post("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun) +def modify_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), + request: ModifyRunRequest = Body(...), +): + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.post("/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["assistants"], response_model=OpenAIRun) +def submit_tool_outputs_to_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), + request: SubmitToolOutputsToRunRequest = Body(...), +): + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@app.post("/v1/threads/{thread_id}/runs/{run_id}/cancel", tags=["assistants"], response_model=OpenAIRun) +def cancel_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), +): + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 1de5d568..2a15e4e4 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -25,6 +25,12 @@ from memgpt.data_types import ( AgentState, LLMConfig, EmbeddingConfig, + Message, + ToolCall, + LLMConfig, + EmbeddingConfig, + Message, + ToolCall, ) @@ -292,7 +298,7 @@ class SyncServer(LockingServer): memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id) return memgpt_agent - def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: str) -> int: + def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: Union[str, Message]) -> int: """Send the input message through the agent""" logger.debug(f"Got input message: {input_message}") @@ -451,7 
+457,7 @@ class SyncServer(LockingServer): self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) @LockingServer.agent_lock_decorator - def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: + def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message]) -> None: """Process an incoming user message and feed it through the MemGPT agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -459,21 +465,27 @@ class SyncServer(LockingServer): raise ValueError(f"Agent agent_id={agent_id} does not exist") # Basic input sanitization - if not isinstance(message, str) or len(message) == 0: - raise ValueError(f"Invalid input: '{message}'") + if isinstance(message, str): + if len(message) == 0: + raise ValueError(f"Invalid input: '{message}'") - # If the input begins with a command prefix, reject - elif message.startswith("/"): - raise ValueError(f"Invalid input: '{message}'") - - # Else, process it as a user message to be fed to the agent - else: - # Package the user message first + # If the input begins with a command prefix, reject + elif message.startswith("/"): + raise ValueError(f"Invalid input: '{message}'") packaged_user_message = system.package_user_message(user_message=message) - # Run the agent state forward - tokens_accumulated = self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message) + elif isinstance(message, Message): + if len(message.text) == 0: + raise ValueError(f"Invalid input: '{message.text}'") - return tokens_accumulated + # If the input begins with a command prefix, reject + elif message.text.startswith("/"): + raise ValueError(f"Invalid input: '{message.text}'") + packaged_user_message = message + else: + raise ValueError(f"Invalid input: '{message}'") + + # Run the agent state forward + self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message) 
@LockingServer.agent_lock_decorator def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: @@ -665,6 +677,19 @@ class SyncServer(LockingServer): memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) return [m.id for m in memgpt_agent._messages] + def get_agent_message(self, agent_id: uuid.UUID, message_id: uuid.UUID) -> Message: + """Get message based on agent and message ID""" + agent_state = self.ms.get_agent(agent_id=agent_id) + if agent_state is None: + raise ValueError(f"Agent agent_id={agent_id} does not exist") + user_id = agent_state.user_id + + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + + message = memgpt_agent.persistence_manager.recall_memory.storage.get(message_id=message_id) + return message + def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: """Paginated query of all messages in agent message queue""" if self.ms.get_user(user_id=user_id) is None: @@ -757,6 +782,7 @@ class SyncServer(LockingServer): before: Optional[uuid.UUID] = None, limit: Optional[int] = 100, order_by: Optional[str] = "created_at", + order: Optional[str] = "asc", reverse: Optional[bool] = False, ): if self.ms.get_user(user_id=user_id) is None: diff --git a/tests/test_openai_assistant_api.py b/tests/test_openai_assistant_api.py new file mode 100644 index 00000000..09f99062 --- /dev/null +++ b/tests/test_openai_assistant_api.py @@ -0,0 +1,50 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient +import uuid + +from memgpt.server.server import SyncServer +from memgpt.server.rest_api.openai_assistants.assistants import app +from memgpt.constants import DEFAULT_PRESET + + +def test_list_messages(): + client = TestClient(app) + + test_user_id = uuid.uuid4() + + # create user + server = SyncServer() + server.create_user({"id": test_user_id}) + + # test: create agent + request_body = { + 
"user_id": str(test_user_id), + "assistant_name": DEFAULT_PRESET, + } + print(request_body) + response = client.post("/v1/threads", json=request_body) + assert response.status_code == 200, f"Error: {response.json()}" + agent_id = response.json()["id"] + print(response.json()) + + # test: insert messages + # TODO: eventually implement the "run" functionality + request_body = { + "user_id": str(test_user_id), + "content": "Hello, world!", + "role": "user", + } + response = client.post(f"/v1/threads/{str(agent_id)}/messages", json=request_body) + assert response.status_code == 200, f"Error: {response.json()}" + + # test: list messages + thread_id = str(agent_id) + params = { + "limit": 10, + "order": "desc", + # "after": "", + "user_id": str(test_user_id), + } + response = client.get(f"/v1/threads/{thread_id}/messages", params=params) + assert response.status_code == 200, f"Error: {response.json()}" + print(response.json()) diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py new file mode 100644 index 00000000..15fc72f0 --- /dev/null +++ b/tests/test_openai_client.py @@ -0,0 +1,89 @@ +from openai import OpenAI +import time +import uvicorn + + +def test_openai_assistant(): + client = OpenAI(base_url="http://127.0.0.1:8080/v1") + + # create assistant + assistant = client.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + # tools=[{"type": "code_interpreter"}], + model="gpt-4-turbo-preview", + ) + + # create thread + thread = client.beta.threads.create() + + message = client.beta.threads.messages.create( + thread_id=thread.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?" + ) + + run = client.beta.threads.runs.create( + thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe. The user has a premium account." 
+ ) + + # run = client.beta.threads.runs.create( + # thread_id=thread.id, + # assistant_id=assistant.id, + # model="gpt-4-turbo-preview", + # instructions="New instructions that override the Assistant instructions", + # tools=[{"type": "code_interpreter"}, {"type": "retrieval"}] + # ) + + # Store the run ID + run_id = run.id + print(run_id) + + # NOTE: MemGPT does not support polling yet, so run status is always "completed" + # Retrieve all messages from the thread + messages = client.beta.threads.messages.list(thread_id=thread.id) + + # Print all messages from the thread + for msg in messages.messages: + role = msg["role"] + content = msg["content"][0] + print(f"{role.capitalize()}: {content}") + + # TODO: add once polling works + ## Polling for the run status + # while True: + # # Retrieve the run status + # run_status = client.beta.threads.runs.retrieve( + # thread_id=thread.id, + # run_id=run_id + # ) + + # # Check and print the step details + # run_steps = client.beta.threads.runs.steps.list( + # thread_id=thread.id, + # run_id=run_id + # ) + # for step in run_steps.data: + # if step.type == 'tool_calls': + # print(f"Tool {step.type} invoked.") + + # # If step involves code execution, print the code + # if step.type == 'code_interpreter': + # print(f"Python Code Executed: {step.step_details['code_interpreter']['input']}") + + # if run_status.status == 'completed': + # # Retrieve all messages from the thread + # messages = client.beta.threads.messages.list( + # thread_id=thread.id + # ) + + # # Print all messages from the thread + # for msg in messages.data: + # role = msg.role + # content = msg.content[0].text.value + # print(f"{role.capitalize()}: {content}") + # break # Exit the polling loop since the run is complete + # elif run_status.status in ['queued', 'in_progress']: + # print(f'{run_status.status.capitalize()}... 
Please wait.') + # time.sleep(1.5) # Wait before checking again + # else: + # print(f"Run status: {run_status.status}") + # break # Exit the polling loop if the status is neither 'in_progress' nor 'completed'