fix: patch bug in verify first message + add ChatCompletionRequest models to the models dir (#985)

This commit is contained in:
Charles Packer
2024-02-09 17:00:33 -08:00
committed by GitHub
parent 2aa2e12502
commit 0727c802f0
2 changed files with 93 additions and 4 deletions

View File

@@ -0,0 +1,89 @@
from typing import List, Union, Optional, Dict, Literal, Any
from pydantic import BaseModel, Field, Json
class SystemMessage(BaseModel):
content: str
role: str = "system"
name: Optional[str] = None
class UserMessage(BaseModel):
content: Union[str, List[str]]
role: str = "user"
name: Optional[str] = None
class AssistantMessage(BaseModel):
content: Optional[str] = None
role: str = "assistant"
name: Optional[str] = None
tool_calls: Optional[List] = None
ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage]
class ResponseFormat(BaseModel):
type: str = Field(default="text", pattern="^(text|json_object)$")
## tool_choice ##
class FunctionCall(BaseModel):
name: str
class ToolFunctionChoice(BaseModel):
# The type of the tool. Currently, only function is supported
type: Literal["function"] = "function"
# type: str = Field(default="function", const=True)
function: FunctionCall
ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice]
## tools ##
class FunctionSchema(BaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None # JSON Schema for the parameters
class Tool(BaseModel):
# The type of the tool. Currently, only function is supported
type: Literal["function"] = "function"
# type: str = Field(default="function", const=True)
function: FunctionSchema
## function_call ##
FunctionCallChoice = Union[Literal["none", "auto"], FunctionCall]
class ChatCompletionRequest(BaseModel):
"""https://platform.openai.com/docs/api-reference/chat/create"""
model: str
messages: List[ChatMessage]
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, int]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
temperature: Optional[float] = 1
top_p: Optional[float] = 1
user: Optional[str] = None # unique ID of the end-user (for monitoring)
# function-calling related
tools: Optional[List[Tool]] = None
tool_choice: Optional[ToolChoice] = "none"
# deprecated scheme
functions: Optional[List[FunctionSchema]] = None
function_call: Optional[FunctionCallChoice] = None

View File

@@ -668,14 +668,14 @@ def verify_first_message_correctness(
response_message = response.choices[0].message
# First message should be a call to send_message with a non-empty content
if ("function_call" in response_message and response_message.function_call is not None) and (
"tool_calls" in response_message and response_message.tool_calls is not None
if (hasattr(response_message, "function_call") and response_message.function_call is not None) and (
hasattr(response_message, "tool_calls") and response_message.tool_calls is not None
):
printd(f"First message includes both function call AND tool call: {response_message}")
return False
elif "function_call" in response_message and response_message.function_call is not None:
elif hasattr(response_message, "function_call") and response_message.function_call is not None:
function_call = response_message.function_call
elif "tool_calls" in response_message and response_message.tool_calls is not None:
elif hasattr(response_message, "tool_calls") and response_message.tool_calls is not None:
function_call = response_message.tool_calls[0].function
else:
printd(f"First message didn't include function call: {response_message}")