feat: Partial support for OpenAI-compatible assistant API (#838)
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv -k "not test_storage and not test_server" tests
|
||||
poetry run pytest -s -vv -k "not test_storage and not test_server and not test_openai_client" tests
|
||||
|
||||
- name: Run storage tests
|
||||
env:
|
||||
|
||||
51
examples/openai_client_assistants.py
Normal file
51
examples/openai_client_assistants.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from openai import OpenAI
|
||||
import time
|
||||
|
||||
"""
|
||||
This script provides an example of how you can use OpenAI's python client with a MemGPT server.
|
||||
|
||||
Before running this example, make sure you start the OpenAI-compatible REST server:
|
||||
cd memgpt/server/rest_api/openai_assistants
|
||||
poetry run uvicorn assistants:app --reload --port 8080
|
||||
"""
|
||||
|
||||
|
||||
def main():
    """Drive a MemGPT server through the official OpenAI python client.

    Exercises the assistant -> thread -> message -> run flow against the
    OpenAI-compatible REST shim, then prints the resulting thread messages.
    """
    # Point the stock OpenAI client at the local MemGPT-compatible server.
    memgpt_client = OpenAI(base_url="http://127.0.0.1:8080/v1")

    # Create an assistant (server-side this maps to a MemGPT preset).
    math_tutor = memgpt_client.beta.assistants.create(
        name="Math Tutor",
        instructions="You are a personal math tutor. Write and run code to answer math questions.",
        model="gpt-4-turbo-preview",
    )

    # Create a thread (server-side this maps to a MemGPT agent).
    conversation = memgpt_client.beta.threads.create()

    # Append a user message to the thread (appends to the MemGPT agent's history).
    memgpt_client.beta.threads.messages.create(
        thread_id=conversation.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?"
    )

    # Execute the agent on the accumulated messages.
    # NOTE: MemGPT does not support polling yet, so run status is always "completed"
    tutoring_run = memgpt_client.beta.threads.runs.create(
        thread_id=conversation.id, assistant_id=math_tutor.id, instructions="Please address the user as Jane Doe. The user has a premium account."
    )

    # Store the run ID (kept for parity with the documented flow; not used below).
    run_id = tutoring_run.id

    # Retrieve and print every message in the thread.
    thread_messages = memgpt_client.beta.threads.messages.list(thread_id=conversation.id)
    for msg in thread_messages.messages:
        role = msg["role"]
        content = msg["content"][0]
        print(f"{role.capitalize()}: {content}")


if __name__ == "__main__":
    main()
|
||||
@@ -4,7 +4,7 @@ import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
import traceback
|
||||
from typing import List, Tuple, Optional, cast
|
||||
from typing import List, Tuple, Optional, cast, Union
|
||||
|
||||
|
||||
from memgpt.data_types import AgentState, Message, EmbeddingConfig
|
||||
@@ -546,7 +546,7 @@ class Agent(object):
|
||||
|
||||
def step(
|
||||
self,
|
||||
user_message: Optional[str], # NOTE: should be json.dump(dict)
|
||||
user_message: Union[Message, str], # NOTE: should be json.dump(dict)
|
||||
first_message: bool = False,
|
||||
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
|
||||
skip_verify: bool = False,
|
||||
@@ -556,11 +556,18 @@ class Agent(object):
|
||||
try:
|
||||
# Step 0: add user message
|
||||
if user_message is not None:
|
||||
# Create the user message dict
|
||||
self.interface.user_message(user_message)
|
||||
packed_user_message = {"role": "user", "content": user_message}
|
||||
if isinstance(user_message, Message):
|
||||
user_message_text = user_message.text
|
||||
elif isinstance(user_message, str):
|
||||
user_message_text = user_message
|
||||
else:
|
||||
raise ValueError(f"Bad type for user_message: {type(user_message)}")
|
||||
|
||||
self.interface.user_message(user_message_text)
|
||||
packed_user_message = {"role": "user", "content": user_message_text}
|
||||
# Special handling for AutoGen messages with 'name' field
|
||||
try:
|
||||
user_message_json = json.loads(user_message, strict=JSON_LOADS_STRICT)
|
||||
user_message_json = json.loads(user_message_text, strict=JSON_LOADS_STRICT)
|
||||
# Special handling for AutoGen messages with 'name' field
|
||||
# Treat 'name' as a special field
|
||||
# If it exists in the input message, elevate it to the 'message' level
|
||||
@@ -629,7 +636,17 @@ class Agent(object):
|
||||
|
||||
# Step 4: extend the message history
|
||||
if user_message is not None:
|
||||
all_new_messages = [packed_user_message_obj] + all_response_messages
|
||||
if isinstance(user_message, Message):
|
||||
all_new_messages = [user_message] + all_response_messages
|
||||
else:
|
||||
all_new_messages = [
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
model=self.model,
|
||||
openai_message_dict=packed_user_message,
|
||||
)
|
||||
] + all_response_messages
|
||||
else:
|
||||
all_new_messages = all_response_messages
|
||||
|
||||
|
||||
157
memgpt/models/openai.py
Normal file
157
memgpt/models/openai.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from typing import List, Union, Optional, Dict, Literal
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, Json
|
||||
import uuid
|
||||
|
||||
|
||||
class ImageFile(BaseModel):
|
||||
type: str = "image_file"
|
||||
file_id: str
|
||||
|
||||
|
||||
class Text(BaseModel):
|
||||
object: str = "text"
|
||||
text: str = Field(..., description="The text content to be processed by the agent.")
|
||||
|
||||
|
||||
class MessageRoleType(str, Enum):
|
||||
user = "user"
|
||||
system = "system"
|
||||
|
||||
|
||||
class OpenAIAssistant(BaseModel):
    """Represents an OpenAI assistant (equivalent to MemGPT preset)"""

    # Mirrors the OpenAI `assistant` object shape; per-field semantics are in
    # each Field description. `...` marks required fields.
    id: str = Field(..., description="The unique identifier of the assistant.")
    name: str = Field(..., description="The name of the assistant.")
    object: str = "assistant"
    description: Optional[str] = Field(None, description="The description of the assistant.")
    created_at: int = Field(..., description="The unix timestamp of when the assistant was created.")
    model: str = Field(..., description="The model used by the assistant.")
    instructions: str = Field(..., description="The instructions for the assistant.")
    # NOTE(review): tools are plain strings here, whereas the OpenAI API uses
    # structured tool objects — confirm this simplification is intentional.
    tools: Optional[List[str]] = Field(None, description="The tools used by the assistant.")
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the assistant.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the assistant.")
|
||||
|
||||
|
||||
class OpenAIMessage(BaseModel):
    """A single message in a thread, mirroring the OpenAI `thread.message` object."""

    id: str = Field(..., description="The unique identifier of the message.")
    object: str = "thread.message"
    created_at: int = Field(..., description="The unix timestamp of when the message was created.")
    thread_id: str = Field(..., description="The unique identifier of the thread.")
    role: str = Field(..., description="Role of the message sender (either 'user' or 'system')")
    # Fixed: this field defaults to None but was annotated as a bare List,
    # which contradicts the default; the annotation is now explicitly Optional.
    content: Optional[List[Union[Text, ImageFile]]] = Field(None, description="The message content to be processed by the agent.")
    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    run_id: Optional[str] = Field(None, description="The unique identifier of the run.")
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.")
    metadata: Optional[Dict] = Field(None, description="Metadata associated with the message.")
|
||||
|
||||
|
||||
class MessageFile(BaseModel):
    """File attached to a thread message (OpenAI `thread.message.file` object)."""

    # NOTE(review): `MessageFile` is re-declared later in this module with
    # Field-based metadata and an extra `message_id` attribute; that later
    # definition shadows this one at import time, making this version dead
    # code — confirm and delete one of the two.
    id: str
    object: str = "thread.message.file"
    created_at: int  # unix timestamp
|
||||
|
||||
|
||||
class OpenAIThread(BaseModel):
    """Represents an OpenAI thread (equivalent to MemGPT agent)"""

    id: str = Field(..., description="The unique identifier of the thread.")
    object: str = "thread"
    created_at: int = Field(..., description="The unix timestamp of when the thread was created.")
    # Fixed: default was None but the annotation was a bare dict; made the
    # annotation Optional so the declared type matches the default (the
    # endpoints in this commit never pass metadata when constructing threads).
    metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.")
|
||||
|
||||
|
||||
class AssistantFile(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the file.")
|
||||
object: str = "assistant.file"
|
||||
created_at: int = Field(..., description="The unix timestamp of when the file was created.")
|
||||
assistant_id: str = Field(..., description="The unique identifier of the assistant.")
|
||||
|
||||
|
||||
class MessageFile(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the file.")
|
||||
object: str = "thread.message.file"
|
||||
created_at: int = Field(..., description="The unix timestamp of when the file was created.")
|
||||
message_id: str = Field(..., description="The unique identifier of the message.")
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
arguments: str = Field(..., description="The arguments of the function.")
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the tool call.")
|
||||
type: str = "function"
|
||||
function: Function = Field(..., description="The function call.")
|
||||
|
||||
|
||||
class ToolCallOutput(BaseModel):
|
||||
tool_call_id: str = Field(..., description="The unique identifier of the tool call.")
|
||||
output: str = Field(..., description="The output of the tool call.")
|
||||
|
||||
|
||||
class RequiredAction(BaseModel):
|
||||
type: str = "submit_tool_outputs"
|
||||
submit_tool_outputs: List[ToolCall]
|
||||
|
||||
|
||||
class OpenAIError(BaseModel):
|
||||
code: str = Field(..., description="The error code.")
|
||||
message: str = Field(..., description="The error message.")
|
||||
|
||||
|
||||
class OpenAIUsage(BaseModel):
|
||||
completion_tokens: int = Field(..., description="The number of tokens used for the run.")
|
||||
prompt_tokens: int = Field(..., description="The number of tokens used for the prompt.")
|
||||
total_tokens: int = Field(..., description="The total number of tokens used for the run.")
|
||||
|
||||
|
||||
class OpenAIMessageCreationStep(BaseModel):
|
||||
type: str = "message_creation"
|
||||
message_id: str = Field(..., description="The unique identifier of the message.")
|
||||
|
||||
|
||||
class OpenAIToolCallsStep(BaseModel):
|
||||
type: str = "tool_calls"
|
||||
tool_calls: List[ToolCall] = Field(..., description="The tool calls.")
|
||||
|
||||
|
||||
class OpenAIRun(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the run.")
|
||||
object: str = "thread.run"
|
||||
created_at: int = Field(..., description="The unix timestamp of when the run was created.")
|
||||
thread_id: str = Field(..., description="The unique identifier of the thread.")
|
||||
assistant_id: str = Field(..., description="The unique identifier of the assistant.")
|
||||
status: str = Field(..., description="The status of the run.")
|
||||
required_action: Optional[RequiredAction] = Field(None, description="The required action of the run.")
|
||||
last_error: Optional[OpenAIError] = Field(None, description="The last error of the run.")
|
||||
expires_at: int = Field(..., description="The unix timestamp of when the run expires.")
|
||||
started_at: Optional[int] = Field(None, description="The unix timestamp of when the run started.")
|
||||
cancelled_at: Optional[int] = Field(None, description="The unix timestamp of when the run was cancelled.")
|
||||
failed_at: Optional[int] = Field(None, description="The unix timestamp of when the run failed.")
|
||||
completed_at: Optional[int] = Field(None, description="The unix timestamp of when the run completed.")
|
||||
model: str = Field(..., description="The model used by the run.")
|
||||
instructions: str = Field(..., description="The instructions for the run.")
|
||||
tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run.") # TODO: also add code interpreter / retrieval
|
||||
file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the run.")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
|
||||
usage: Optional[OpenAIUsage] = Field(None, description="The usage of the run.")
|
||||
|
||||
|
||||
class OpenAIRunStep(BaseModel):
    """One step of a run, mirroring the OpenAI `thread.run.step` object."""

    id: str = Field(..., description="The unique identifier of the run step.")
    object: str = "thread.run.step"
    created_at: int = Field(..., description="The unix timestamp of when the run step was created.")
    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    thread_id: str = Field(..., description="The unique identifier of the thread.")
    run_id: str = Field(..., description="The unique identifier of the run.")
    type: str = Field(..., description="The type of the run step.")  # message_creation, tool_calls
    status: str = Field(..., description="The status of the run step.")
    # NOTE(review): the OpenAI API names this field `step_details`; confirm
    # `step_defaults` is intentional before clients depend on it.
    step_defaults: Union[OpenAIToolCallsStep, OpenAIMessageCreationStep] = Field(..., description="The step defaults.")
    last_error: Optional[OpenAIError] = Field(None, description="The last error of the run step.")
    expired_at: Optional[int] = Field(None, description="The unix timestamp of when the run step expired.")
    failed_at: Optional[int] = Field(None, description="The unix timestamp of when the run failed.")
    completed_at: Optional[int] = Field(None, description="The unix timestamp of when the run completed.")
    usage: Optional[OpenAIUsage] = Field(None, description="The usage of the run.")
|
||||
532
memgpt/server/rest_api/openai_assistants/assistants.py
Normal file
532
memgpt/server/rest_api/openai_assistants/assistants.py
Normal file
@@ -0,0 +1,532 @@
|
||||
import asyncio
|
||||
from fastapi import FastAPI
|
||||
from asyncio import AbstractEventLoop
|
||||
from enum import Enum
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException, Query, Path
|
||||
from pydantic import BaseModel, Field, constr, validator
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
import uuid
|
||||
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.rest_api.static_files import mount_static_files
|
||||
from memgpt.models.openai import (
|
||||
AssistantFile,
|
||||
MessageFile,
|
||||
OpenAIAssistant,
|
||||
OpenAIThread,
|
||||
OpenAIMessage,
|
||||
OpenAIRun,
|
||||
OpenAIRunStep,
|
||||
MessageRoleType,
|
||||
Text,
|
||||
ImageFile,
|
||||
ToolCall,
|
||||
ToolCallOutput,
|
||||
)
|
||||
from memgpt.data_types import LLMConfig, EmbeddingConfig, Message
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
|
||||
"""
|
||||
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
|
||||
|
||||
Start the server with:
|
||||
cd memgpt/server/rest_api/openai_assistants
|
||||
poetry run uvicorn assistants:app --reload --port 8080
|
||||
"""
|
||||
|
||||
interface: QueuingInterface = QueuingInterface()
|
||||
server: SyncServer = SyncServer(default_interface=interface)
|
||||
|
||||
|
||||
# router = APIRouter()
|
||||
app = FastAPI()
|
||||
|
||||
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
|
||||
print(f"User ID: {user_id}")
|
||||
|
||||
|
||||
class CreateAssistantRequest(BaseModel):
|
||||
model: str = Field(..., description="The model to use for the assistant.")
|
||||
name: str = Field(..., description="The name of the assistant.")
|
||||
description: str = Field(None, description="The description of the assistant.")
|
||||
instructions: str = Field(..., description="The instructions for the assistant.")
|
||||
tools: List[str] = Field(None, description="The tools used by the assistant.")
|
||||
file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.")
|
||||
metadata: dict = Field(None, description="Metadata associated with the assistant.")
|
||||
|
||||
# memgpt-only (not openai)
|
||||
embedding_model: str = Field(None, description="The model to use for the assistant.")
|
||||
|
||||
## TODO: remove
|
||||
# user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
|
||||
|
||||
class CreateThreadRequest(BaseModel):
|
||||
messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.")
|
||||
|
||||
# memgpt-only
|
||||
assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. MemGPT preset)")
|
||||
|
||||
|
||||
class ModifyThreadRequest(BaseModel):
|
||||
metadata: dict = Field(None, description="Metadata associated with the thread.")
|
||||
|
||||
|
||||
class ModifyMessageRequest(BaseModel):
|
||||
metadata: dict = Field(None, description="Metadata associated with the message.")
|
||||
|
||||
|
||||
class ModifyRunRequest(BaseModel):
|
||||
metadata: dict = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class CreateMessageRequest(BaseModel):
|
||||
role: str = Field(..., description="Role of the message sender (either 'user' or 'system')")
|
||||
content: str = Field(..., description="The message content to be processed by the agent.")
|
||||
file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the message.")
|
||||
|
||||
|
||||
class UserMessageRequest(BaseModel):
|
||||
user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
message: str = Field(..., description="The message content to be processed by the agent.")
|
||||
stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
|
||||
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
|
||||
|
||||
|
||||
class UserMessageResponse(BaseModel):
|
||||
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
|
||||
|
||||
|
||||
class GetAgentMessagesRequest(BaseModel):
|
||||
user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
start: int = Field(..., description="Message index to start on (reverse chronological).")
|
||||
count: int = Field(..., description="How many messages to retrieve.")
|
||||
|
||||
|
||||
class ListMessagesResponse(BaseModel):
|
||||
messages: List[OpenAIMessage] = Field(..., description="List of message objects.")
|
||||
|
||||
|
||||
class CreateAssistantFileRequest(BaseModel):
|
||||
file_id: str = Field(..., description="The unique identifier of the file.")
|
||||
|
||||
|
||||
class CreateRunRequest(BaseModel):
|
||||
assistant_id: str = Field(..., description="The unique identifier of the assistant.")
|
||||
model: Optional[str] = Field(None, description="The model used by the run.")
|
||||
instructions: str = Field(..., description="The instructions for the run.")
|
||||
additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.")
|
||||
tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class CreateThreadRunRequest(BaseModel):
|
||||
assistant_id: str = Field(..., description="The unique identifier of the assistant.")
|
||||
thread: OpenAIThread = Field(..., description="The thread to run.")
|
||||
model: str = Field(..., description="The model used by the run.")
|
||||
instructions: str = Field(..., description="The instructions for the run.")
|
||||
tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class DeleteAssistantResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the agent.")
|
||||
object: str = "assistant.deleted"
|
||||
deleted: bool = Field(..., description="Whether the agent was deleted.")
|
||||
|
||||
|
||||
class DeleteAssistantFileResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the file.")
|
||||
object: str = "assistant.file.deleted"
|
||||
deleted: bool = Field(..., description="Whether the file was deleted.")
|
||||
|
||||
|
||||
class DeleteThreadResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the agent.")
|
||||
object: str = "thread.deleted"
|
||||
deleted: bool = Field(..., description="Whether the agent was deleted.")
|
||||
|
||||
|
||||
class SubmitToolOutputsToRunRequest(BaseModel):
|
||||
tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.")
|
||||
|
||||
|
||||
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
|
||||
|
||||
|
||||
@app.get("/v1/health", tags=["assistant"])
def get_health():
    """Liveness probe: report that the OpenAI-compatible shim is up."""
    health_payload = {"status": "healthy"}
    return health_payload
|
||||
|
||||
|
||||
# create assistant (MemGPT agent)
|
||||
@app.post("/v1/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
    """Create an assistant (stub — no MemGPT preset is created yet).

    NOTE(review): the response echoes the request but hard-codes
    ``id=DEFAULT_PRESET`` and ``name="default_preset"``, ignoring
    ``request.name`` — confirm this is the intended placeholder behavior.
    """
    # TODO: create preset
    return OpenAIAssistant(
        id=DEFAULT_PRESET,
        name="default_preset",
        description=request.description,
        created_at=int(datetime.now().timestamp()),
        model=request.model,
        instructions=request.instructions,
        tools=request.tools,
        file_ids=request.file_ids,
        metadata=request.metadata,
    )
|
||||
|
||||
|
||||
@app.post("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
|
||||
def create_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
request: CreateAssistantFileRequest = Body(...),
|
||||
):
|
||||
# TODO: add file to assistant
|
||||
return AssistantFile(
|
||||
id=request.file_id,
|
||||
created_at=int(datetime.now().timestamp()),
|
||||
assistant_id=assistant_id,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
|
||||
def list_assistants(
|
||||
limit: int = Query(1000, description="How many assistants to retrieve."),
|
||||
order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
):
|
||||
# TODO: implement list assistants (i.e. list available MemGPT presets)
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
|
||||
def list_assistant_files(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
limit: int = Query(1000, description="How many files to retrieve."),
|
||||
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
):
|
||||
# TODO: list attached data sources to preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def retrieve_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
):
|
||||
# TODO: get and return preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
|
||||
def retrieve_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: return data source attached to preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def modify_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
request: CreateAssistantRequest = Body(...),
|
||||
):
|
||||
# TODO: modify preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.delete("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
|
||||
def delete_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
):
|
||||
# TODO: delete preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.delete("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
|
||||
def delete_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: delete source on preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads", tags=["assistants"], response_model=OpenAIThread)
def create_thread(request: CreateThreadRequest = Body(...)):
    """Create a thread by spawning a new MemGPT agent for the module-level user.

    Returns the new agent's id/creation time in OpenAI `thread` form.
    NOTE(review): request.messages and request.metadata are currently ignored
    (see TODOs below).
    """
    # TODO: use requests.description and requests.metadata fields
    # TODO: handle requests.file_ids and requests.tools
    # TODO: eventually allow request to override embedding/llm model

    print("Create thread/agent", request)
    # create a memgpt agent
    agent_state = server.create_agent(
        user_id=user_id,
        agent_config={
            "user_id": user_id,
        },
    )
    # TODO: insert messages into recall memory
    return OpenAIThread(
        id=str(agent_state.id),
        created_at=int(agent_state.created_at.timestamp()),
    )
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
|
||||
def retrieve_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
):
|
||||
agent = server.get_agent(uuid.UUID(thread_id))
|
||||
return OpenAIThread(
|
||||
id=str(agent.id),
|
||||
created_at=int(agent.created_at.timestamp()),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
def modify_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: ModifyThreadRequest = Body(...),
):
    """Modify a thread's metadata (not yet implemented).

    BUG FIX: this handler was registered with ``@app.get`` on the same path as
    ``retrieve_thread``, creating a duplicate GET route; it now uses POST,
    matching the OpenAI modify semantics and the ``modify_assistant`` /
    ``modify_message`` endpoints in this module.
    """
    # TODO: add agent metadata so this can be modified
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.delete("/v1/threads/{thread_id}", tags=["assistants"], response_model=DeleteThreadResponse)
|
||||
def delete_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
):
|
||||
# TODO: delete agent
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=OpenAIMessage)
def create_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: CreateMessageRequest = Body(...),
):
    """Append a message to a thread by writing it into the MemGPT agent's history.

    The message is only stored here; a subsequent run executes the agent on it.
    Returns the stored message echoed back in OpenAI `thread.message` form.
    """
    agent_id = uuid.UUID(thread_id)
    # create message object
    message = Message(
        user_id=user_id,
        agent_id=agent_id,
        role=request.role,
        text=request.content,
    )
    agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
    # add message to agent
    # NOTE(review): relies on the agent's private _append_to_messages API —
    # confirm no public entry point exists for this.
    agent._append_to_messages([message])

    openai_message = OpenAIMessage(
        id=str(message.id),
        created_at=int(message.created_at.timestamp()),
        content=[Text(text=message.text)],
        role=message.role,
        thread_id=str(message.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
    return openai_message
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=ListMessagesResponse)
def list_messages(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many messages to retrieve."),
    order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
    """List an agent's recall-memory messages in OpenAI `thread.message` form.

    `after`/`before` are message-ID cursors and `order` ('asc'/'desc')
    controls chronological direction, mirroring the OpenAI list API.

    BUG FIXES: the `after` cursor was previously gated on `before`
    (`uuid.UUID(after) if before else None`), so a lone `after` was silently
    ignored; and an unguarded debug `print(json_messages[0]["text"])` raised
    IndexError whenever the thread had no messages. Debug prints removed.
    """
    after_uuid = uuid.UUID(after) if after else None
    before_uuid = uuid.UUID(before) if before else None
    agent_id = uuid.UUID(thread_id)
    reverse = True if (order == "desc") else False
    cursor, json_messages = server.get_agent_recall_cursor(
        user_id=user_id,
        agent_id=agent_id,
        limit=limit,
        after=after_uuid,
        before=before_uuid,
        order_by="created_at",
        reverse=reverse,
    )
    # convert to openai style messages
    openai_messages = [
        OpenAIMessage(
            id=str(message["id"]),
            created_at=int(message["created_at"].timestamp()),
            content=[Text(text=message["text"])],
            role=message["role"],
            thread_id=str(message["agent_id"]),
            assistant_id=DEFAULT_PRESET,  # TODO: update this
            # file_ids=message.file_ids,
            # metadata=message.metadata,
        )
        for message in json_messages
    ]
    # TODO: cast back to message objects
    return ListMessagesResponse(messages=openai_messages)
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
def retrieve_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
):
    """Fetch a single message from a thread (MemGPT agent) by its ID.

    BUG FIX: the route was previously a bare ``app.get(...)`` expression with
    no ``@`` — the call returned an unused decorator and the handler was never
    registered with FastAPI. Also stopped rebinding the ``message_id`` path
    parameter to a UUID, which shadowed the string argument.
    """
    message_uuid = uuid.UUID(message_id)
    agent_id = uuid.UUID(thread_id)
    message = server.get_agent_message(agent_id, message_uuid)
    return OpenAIMessage(
        id=str(message.id),
        created_at=int(message.created_at.timestamp()),
        content=[Text(text=message.text)],
        role=message.role,
        thread_id=str(message.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["assistants"], response_model=MessageFile)
def retrieve_message_file(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    """Retrieve a file attached to a message. Not yet supported by MemGPT."""
    # TODO: implement? NOTE(review): 404 matches this file's other stubs, though
    # 501 Not Implemented would arguably be more accurate — confirm before changing.
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
def modify_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    request: ModifyMessageRequest = Body(...),
):
    """Modify a message's metadata. Not yet supported by MemGPT."""
    # TODO: add a metadata field to Message so this endpoint can actually modify it
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=OpenAIRun)
def create_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: CreateRunRequest = Body(...),
):
    """Execute the MemGPT agent on the messages already queued in the thread.

    MemGPT runs synchronously, so the returned run status is always "completed".
    """
    # TODO: add request.instructions as a message?
    agent_uuid = uuid.UUID(thread_id)
    # TODO: override preset of agent with request.assistant_id
    memgpt_agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_uuid)
    memgpt_agent.step(user_message=None)  # the thread's messages were appended beforehand
    now_ts = int(datetime.now().timestamp())
    return OpenAIRun(
        id=str(uuid.uuid4()),
        created_at=now_ts,
        thread_id=str(agent_uuid),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        status="completed",  # TODO: eventually allow offline execution
        expires_at=now_ts,  # NOTE(review): expires at creation time — confirm this is intended
        model=memgpt_agent.agent_state.llm_config.model,
        instructions=request.instructions,
    )
|
||||
|
||||
|
||||
@app.post("/v1/threads/runs", tags=["assistants"], response_model=OpenAIRun)
def create_thread_and_run(
    request: CreateThreadRunRequest = Body(...),
):
    """Create a thread, seed it with messages, and execute a run in one call. Not yet supported."""
    # TODO: create the agent, append request.thread.messages, then step the agent
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=List[OpenAIRun])
def list_runs(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many runs to retrieve."),
    order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List the runs of a thread. Not yet supported by MemGPT.

    FIX: the `before` query-parameter description previously (incorrectly) described `after`.
    """
    # TODO: store run information in a DB so it can be returned here
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps", tags=["assistants"], response_model=List[OpenAIRunStep])
def list_run_steps(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    limit: int = Query(1000, description="How many run steps to retrieve."),
    order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List the steps of a run. Not yet supported by MemGPT.

    FIX: the `before` query-parameter description previously (incorrectly) described `after`.
    """
    # TODO: store run information in a DB so it can be returned here
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
def retrieve_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
):
    """Retrieve a single run. Not yet supported by MemGPT (runs are not persisted)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["assistants"], response_model=OpenAIRunStep)
def retrieve_run_step(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    step_id: str = Path(..., description="The unique identifier of the run step."),
):
    """Retrieve a single run step. Not yet supported by MemGPT (runs are not persisted)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
def modify_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    request: ModifyRunRequest = Body(...),
):
    """Modify a run's metadata. Not yet supported by MemGPT (runs are not persisted)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["assistants"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    request: SubmitToolOutputsToRunRequest = Body(...),
):
    """Submit tool outputs for a run waiting on tool calls. Not yet supported by MemGPT."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/runs/{run_id}/cancel", tags=["assistants"], response_model=OpenAIRun)
def cancel_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
):
    """Cancel an in-progress run. Not yet supported by MemGPT (runs complete synchronously)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
@@ -25,6 +25,12 @@ from memgpt.data_types import (
|
||||
AgentState,
|
||||
LLMConfig,
|
||||
EmbeddingConfig,
|
||||
Message,
|
||||
ToolCall,
|
||||
LLMConfig,
|
||||
EmbeddingConfig,
|
||||
Message,
|
||||
ToolCall,
|
||||
)
|
||||
|
||||
|
||||
@@ -292,7 +298,7 @@ class SyncServer(LockingServer):
|
||||
memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
|
||||
return memgpt_agent
|
||||
|
||||
def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: str) -> int:
|
||||
def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: Union[str, Message]) -> int:
|
||||
"""Send the input message through the agent"""
|
||||
|
||||
logger.debug(f"Got input message: {input_message}")
|
||||
@@ -451,7 +457,7 @@ class SyncServer(LockingServer):
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
|
||||
|
||||
@LockingServer.agent_lock_decorator
|
||||
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
|
||||
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message]) -> None:
|
||||
"""Process an incoming user message and feed it through the MemGPT agent"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
@@ -459,21 +465,27 @@ class SyncServer(LockingServer):
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
|
||||
# Basic input sanitization
|
||||
if not isinstance(message, str) or len(message) == 0:
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
if isinstance(message, str):
|
||||
if len(message) == 0:
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
|
||||
# If the input begins with a command prefix, reject
|
||||
elif message.startswith("/"):
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
|
||||
# Else, process it as a user message to be fed to the agent
|
||||
else:
|
||||
# Package the user message first
|
||||
# If the input begins with a command prefix, reject
|
||||
elif message.startswith("/"):
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
packaged_user_message = system.package_user_message(user_message=message)
|
||||
# Run the agent state forward
|
||||
tokens_accumulated = self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
|
||||
elif isinstance(message, Message):
|
||||
if len(message.text) == 0:
|
||||
raise ValueError(f"Invalid input: '{message.text}'")
|
||||
|
||||
return tokens_accumulated
|
||||
# If the input begins with a command prefix, reject
|
||||
elif message.text.startswith("/"):
|
||||
raise ValueError(f"Invalid input: '{message.text}'")
|
||||
packaged_user_message = message
|
||||
else:
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
|
||||
# Run the agent state forward
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
|
||||
|
||||
@LockingServer.agent_lock_decorator
|
||||
def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
|
||||
@@ -665,6 +677,19 @@ class SyncServer(LockingServer):
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
return [m.id for m in memgpt_agent._messages]
|
||||
|
||||
def get_agent_message(self, agent_id: uuid.UUID, message_id: uuid.UUID) -> Message:
|
||||
"""Get message based on agent and message ID"""
|
||||
agent_state = self.ms.get_agent(agent_id=agent_id)
|
||||
if agent_state is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
user_id = agent_state.user_id
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
message = memgpt_agent.persistence_manager.recall_memory.storage.get(message_id=message_id)
|
||||
return message
|
||||
|
||||
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||
"""Paginated query of all messages in agent message queue"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
@@ -757,6 +782,7 @@ class SyncServer(LockingServer):
|
||||
before: Optional[uuid.UUID] = None,
|
||||
limit: Optional[int] = 100,
|
||||
order_by: Optional[str] = "created_at",
|
||||
order: Optional[str] = "asc",
|
||||
reverse: Optional[bool] = False,
|
||||
):
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
|
||||
50
tests/test_openai_assistant_api.py
Normal file
50
tests/test_openai_assistant_api.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
import uuid
|
||||
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.server.rest_api.openai_assistants.assistants import app
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
|
||||
|
||||
def test_list_messages():
    """End-to-end check of the OpenAI-compatible thread + message endpoints."""
    client = TestClient(app)

    test_user_id = uuid.uuid4()

    # create the backing MemGPT user
    server = SyncServer()
    server.create_user({"id": test_user_id})

    # test: create a thread (a MemGPT agent under the hood)
    create_thread_body = {
        "user_id": str(test_user_id),
        "assistant_name": DEFAULT_PRESET,
    }
    print(create_thread_body)
    response = client.post("/v1/threads", json=create_thread_body)
    assert response.status_code == 200, f"Error: {response.json()}"
    agent_id = response.json()["id"]
    print(response.json())

    # test: append a user message to the thread
    # TODO: eventually implement the "run" functionality
    create_message_body = {
        "user_id": str(test_user_id),
        "content": "Hello, world!",
        "role": "user",
    }
    response = client.post(f"/v1/threads/{str(agent_id)}/messages", json=create_message_body)
    assert response.status_code == 200, f"Error: {response.json()}"

    # test: list the messages back out
    thread_id = str(agent_id)
    query_params = {
        "limit": 10,
        "order": "desc",
        # "after": "",
        "user_id": str(test_user_id),
    }
    response = client.get(f"/v1/threads/{thread_id}/messages", params=query_params)
    assert response.status_code == 200, f"Error: {response.json()}"
    print(response.json())
|
||||
89
tests/test_openai_client.py
Normal file
89
tests/test_openai_client.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from openai import OpenAI
|
||||
import time
|
||||
import uvicorn
|
||||
|
||||
|
||||
def test_openai_assistant():
    """Drive a local MemGPT OpenAI-compatible server through the official OpenAI client.

    Requires the assistants REST server to be listening on port 8080 (this test is
    excluded from CI — see .github/workflows/tests.yml).
    """
    client = OpenAI(base_url="http://127.0.0.1:8080/v1")

    # create assistant (maps to a MemGPT preset)
    assistant = client.beta.assistants.create(
        name="Math Tutor",
        instructions="You are a personal math tutor. Write and run code to answer math questions.",
        # tools=[{"type": "code_interpreter"}],
        model="gpt-4-turbo-preview",
    )

    # create thread (maps to a MemGPT agent)
    thread = client.beta.threads.create()

    # append a user message to the thread
    message = client.beta.threads.messages.create(
        thread_id=thread.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?"
    )

    # create a run, executing the agent on the queued messages
    run = client.beta.threads.runs.create(
        thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe. The user has a premium account."
    )

    run_id = run.id
    print(run_id)

    # NOTE: MemGPT does not support polling yet, so run status is always "completed"
    # Retrieve all messages from the thread
    messages = client.beta.threads.messages.list(thread_id=thread.id)

    # Print every message in the thread.
    # NOTE(review): the stock OpenAI client exposes the page items as `.data`;
    # `.messages` presumably relies on MemGPT's ListMessagesResponse shape — verify.
    for msg in messages.messages:
        role = msg["role"]
        content = msg["content"][0]
        print(f"{role.capitalize()}: {content}")

    # TODO: once polling works, loop on runs.retrieve(...) until the status leaves
    # 'queued'/'in_progress', inspect runs.steps.list(...) for tool/code steps, and
    # only then list + print the thread's messages.
|
||||
Reference in New Issue
Block a user