fix: add max tokens (#795)
Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
@@ -111,7 +111,6 @@ def create(
|
||||
# streaming?
|
||||
stream: bool = False,
|
||||
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
|
||||
- max_tokens: Optional[int] = None,
|
||||
model_settings: Optional[dict] = None, # TODO: eventually pass from server
|
||||
) -> ChatCompletionResponse:
|
||||
"""Return response to chat completion with backoff"""
|
||||
@@ -157,7 +156,7 @@ def create(
|
||||
else:
|
||||
function_call = "required"
|
||||
|
||||
- data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
|
||||
+ data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming)
|
||||
if stream: # Client requested token streaming
|
||||
data.stream = True
|
||||
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
|
||||
@@ -212,7 +211,7 @@ def create(
|
||||
# For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
|
||||
llm_config.model_endpoint = model_settings.azure_base_url
|
||||
chat_completion_request = build_openai_chat_completions_request(
|
||||
- llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens
|
||||
+ llm_config, messages, user_id, functions, function_call, use_tool_naming
|
||||
)
|
||||
|
||||
response = azure_openai_chat_completions_request(
|
||||
@@ -248,7 +247,7 @@ def create(
|
||||
data=dict(
|
||||
contents=[m.to_google_ai_dict() for m in messages],
|
||||
tools=tools,
|
||||
- generation_config={"temperature": llm_config.temperature},
|
||||
+ generation_config={"temperature": llm_config.temperature, "max_output_tokens": llm_config.max_tokens},
|
||||
),
|
||||
inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
@@ -268,7 +267,7 @@ def create(
|
||||
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
||||
tools=([{"type": "function", "function": f} for f in functions] if functions else None),
|
||||
tool_choice=tool_call,
|
||||
- max_tokens=1024, # TODO make dynamic
|
||||
+ max_tokens=llm_config.max_tokens, # Note: max_tokens is required for Anthropic API
|
||||
temperature=llm_config.temperature,
|
||||
stream=stream,
|
||||
)
|
||||
@@ -416,7 +415,7 @@ def create(
|
||||
tool_choice=tool_call,
|
||||
# user=str(user_id),
|
||||
# NOTE: max_tokens is required for Anthropic API
|
||||
- max_tokens=1024, # TODO make dynamic
|
||||
+ max_tokens=llm_config.max_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -94,7 +94,6 @@ def build_openai_chat_completions_request(
|
||||
functions: Optional[list],
|
||||
function_call: Optional[str],
|
||||
use_tool_naming: bool,
|
||||
- max_tokens: Optional[int],
|
||||
) -> ChatCompletionRequest:
|
||||
if functions and llm_config.put_inner_thoughts_in_kwargs:
|
||||
# Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
|
||||
@@ -131,7 +130,7 @@ def build_openai_chat_completions_request(
|
||||
tools=[Tool(type="function", function=f) for f in functions] if functions else None,
|
||||
tool_choice=tool_choice,
|
||||
user=str(user_id),
|
||||
- max_completion_tokens=max_tokens,
|
||||
+ max_completion_tokens=llm_config.max_tokens,
|
||||
temperature=llm_config.temperature,
|
||||
)
|
||||
else:
|
||||
@@ -141,7 +140,7 @@ def build_openai_chat_completions_request(
|
||||
functions=functions,
|
||||
function_call=function_call,
|
||||
user=str(user_id),
|
||||
- max_completion_tokens=max_tokens,
|
||||
+ max_completion_tokens=llm_config.max_tokens,
|
||||
temperature=llm_config.temperature,
|
||||
)
|
||||
# https://platform.openai.com/docs/guides/text-generation/json-mode
|
||||
|
||||
@@ -15,6 +15,7 @@ class LLMConfig(BaseModel):
|
||||
context_window (int): The context window size for the model.
|
||||
put_inner_thoughts_in_kwargs (bool): Puts `inner_thoughts` as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.
|
||||
temperature (float): The temperature to use when generating text with the model. A higher temperature will result in more random text.
|
||||
+ max_tokens (int): The maximum number of tokens to generate.
|
||||
"""
|
||||
|
||||
# TODO: 🤮 don't default to a vendor! bug city!
|
||||
@@ -51,6 +52,10 @@ class LLMConfig(BaseModel):
|
||||
0.7,
|
||||
description="The temperature to use when generating text with the model. A higher temperature will result in more random text.",
|
||||
)
|
||||
+ max_tokens: Optional[int] = Field(
|
||||
+ 1024,
|
||||
+ description="The maximum number of tokens to generate. If not set, the model will use its default value.",
|
||||
+ )
|
||||
|
||||
# FIXME hack to silence pydantic protected namespace warning
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
Reference in New Issue
Block a user