diff --git a/memgpt/models/chat_completion_request.py b/memgpt/models/chat_completion_request.py new file mode 100644 index 00000000..7adbd1ab --- /dev/null +++ b/memgpt/models/chat_completion_request.py @@ -0,0 +1,89 @@ +from typing import List, Union, Optional, Dict, Literal, Any +from pydantic import BaseModel, Field, Json + + +class SystemMessage(BaseModel): + content: str + role: str = "system" + name: Optional[str] = None + + +class UserMessage(BaseModel): + content: Union[str, List[str]] + role: str = "user" + name: Optional[str] = None + + +class AssistantMessage(BaseModel): + content: Optional[str] = None + role: str = "assistant" + name: Optional[str] = None + tool_calls: Optional[List] = None + + +ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage] + + +class ResponseFormat(BaseModel): + type: str = Field(default="text", pattern="^(text|json_object)$") + + +## tool_choice ## +class FunctionCall(BaseModel): + name: str + + +class ToolFunctionChoice(BaseModel): + # The type of the tool. Currently, only function is supported + type: Literal["function"] = "function" + # type: str = Field(default="function", const=True) + function: FunctionCall + + +ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice] + + +## tools ## +class FunctionSchema(BaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None # JSON Schema for the parameters + + +class Tool(BaseModel): + # The type of the tool. Currently, only function is supported + type: Literal["function"] = "function" + # type: str = Field(default="function", const=True) + function: FunctionSchema + + +## function_call ## +FunctionCallChoice = Union[Literal["none", "auto"], FunctionCall] + + +class ChatCompletionRequest(BaseModel): + """https://platform.openai.com/docs/api-reference/chat/create""" + + model: str + messages: List[ChatMessage] + frequency_penalty: Optional[float] = 0 + logit_bias: Optional[Dict[str, int]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None + max_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0 + response_format: Optional[ResponseFormat] = None + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + temperature: Optional[float] = 1 + top_p: Optional[float] = 1 + user: Optional[str] = None # unique ID of the end-user (for monitoring) + + # function-calling related + tools: Optional[List[Tool]] = None + tool_choice: Optional[ToolChoice] = "none" + # deprecated scheme + functions: Optional[List[FunctionSchema]] = None + function_call: Optional[FunctionCallChoice] = None diff --git a/memgpt/utils.py b/memgpt/utils.py index 78dd7f91..05e79526 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -668,14 +668,14 @@ def verify_first_message_correctness( response_message = response.choices[0].message # First message should be a call to send_message with a non-empty content - if ("function_call" in response_message and response_message.function_call is not None) and ( - "tool_calls" in response_message and response_message.tool_calls is not None + if (hasattr(response_message, "function_call") and response_message.function_call is not None) and ( + hasattr(response_message, "tool_calls") and response_message.tool_calls is not None ): printd(f"First message includes both function call AND tool call: {response_message}") return False - elif "function_call" in response_message and response_message.function_call is not None: + elif hasattr(response_message, "function_call") and response_message.function_call is not None: function_call = response_message.function_call - elif "tool_calls" in response_message and response_message.tool_calls is not None: + elif hasattr(response_message, "tool_calls") and response_message.tool_calls is not None: function_call = response_message.tool_calls[0].function else: printd(f"First message didn't include function call: {response_message}")