fix: patch bug in verify first message + add ChatCompletionRequest models to the models dir (#985)
This commit is contained in:
89
memgpt/models/chat_completion_request.py
Normal file
89
memgpt/models/chat_completion_request.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import List, Union, Optional, Dict, Literal, Any
|
||||
from pydantic import BaseModel, Field, Json
|
||||
|
||||
|
||||
class SystemMessage(BaseModel):
|
||||
content: str
|
||||
role: str = "system"
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
content: Union[str, List[str]]
|
||||
role: str = "user"
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class AssistantMessage(BaseModel):
|
||||
content: Optional[str] = None
|
||||
role: str = "assistant"
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List] = None
|
||||
|
||||
|
||||
ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage]
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
type: str = Field(default="text", pattern="^(text|json_object)$")
|
||||
|
||||
|
||||
## tool_choice ##
|
||||
class FunctionCall(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class ToolFunctionChoice(BaseModel):
|
||||
# The type of the tool. Currently, only function is supported
|
||||
type: Literal["function"] = "function"
|
||||
# type: str = Field(default="function", const=True)
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice]
|
||||
|
||||
|
||||
## tools ##
|
||||
class FunctionSchema(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None # JSON Schema for the parameters
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
# The type of the tool. Currently, only function is supported
|
||||
type: Literal["function"] = "function"
|
||||
# type: str = Field(default="function", const=True)
|
||||
function: FunctionSchema
|
||||
|
||||
|
||||
## function_call ##
|
||||
FunctionCallChoice = Union[Literal["none", "auto"], FunctionCall]
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""https://platform.openai.com/docs/api-reference/chat/create"""
|
||||
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
frequency_penalty: Optional[float] = 0
|
||||
logit_bias: Optional[Dict[str, int]] = None
|
||||
logprobs: Optional[bool] = False
|
||||
top_logprobs: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
n: Optional[int] = 1
|
||||
presence_penalty: Optional[float] = 0
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: Optional[bool] = False
|
||||
temperature: Optional[float] = 1
|
||||
top_p: Optional[float] = 1
|
||||
user: Optional[str] = None # unique ID of the end-user (for monitoring)
|
||||
|
||||
# function-calling related
|
||||
tools: Optional[List[Tool]] = None
|
||||
tool_choice: Optional[ToolChoice] = "none"
|
||||
# deprecated scheme
|
||||
functions: Optional[List[FunctionSchema]] = None
|
||||
function_call: Optional[FunctionCallChoice] = None
|
||||
@@ -668,14 +668,14 @@ def verify_first_message_correctness(
|
||||
response_message = response.choices[0].message
|
||||
|
||||
# First message should be a call to send_message with a non-empty content
|
||||
if ("function_call" in response_message and response_message.function_call is not None) and (
|
||||
"tool_calls" in response_message and response_message.tool_calls is not None
|
||||
if (hasattr(response_message, "function_call") and response_message.function_call is not None) and (
|
||||
hasattr(response_message, "tool_calls") and response_message.tool_calls is not None
|
||||
):
|
||||
printd(f"First message includes both function call AND tool call: {response_message}")
|
||||
return False
|
||||
elif "function_call" in response_message and response_message.function_call is not None:
|
||||
elif hasattr(response_message, "function_call") and response_message.function_call is not None:
|
||||
function_call = response_message.function_call
|
||||
elif "tool_calls" in response_message and response_message.tool_calls is not None:
|
||||
elif hasattr(response_message, "tool_calls") and response_message.tool_calls is not None:
|
||||
function_call = response_message.tool_calls[0].function
|
||||
else:
|
||||
printd(f"First message didn't include function call: {response_message}")
|
||||
|
||||
Reference in New Issue
Block a user