fix: add max tokens (#795)

Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
Kevin Lin
2025-02-10 20:28:03 -08:00
committed by GitHub
parent 308e6ce215
commit e858871ccb
3 changed files with 12 additions and 9 deletions

View File

@@ -111,7 +111,6 @@ def create(
# streaming?
stream: bool = False,
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
max_tokens: Optional[int] = None,
model_settings: Optional[dict] = None, # TODO: eventually pass from server
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
@@ -157,7 +156,7 @@ def create(
else:
function_call = "required"
data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming)
if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
@@ -212,7 +211,7 @@ def create(
# For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
llm_config.model_endpoint = model_settings.azure_base_url
chat_completion_request = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens
llm_config, messages, user_id, functions, function_call, use_tool_naming
)
response = azure_openai_chat_completions_request(
@@ -248,7 +247,7 @@ def create(
data=dict(
contents=[m.to_google_ai_dict() for m in messages],
tools=tools,
generation_config={"temperature": llm_config.temperature},
generation_config={"temperature": llm_config.temperature, "max_output_tokens": llm_config.max_tokens},
),
inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
)
@@ -268,7 +267,7 @@ def create(
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=([{"type": "function", "function": f} for f in functions] if functions else None),
tool_choice=tool_call,
max_tokens=1024, # TODO make dynamic
max_tokens=llm_config.max_tokens, # Note: max_tokens is required for Anthropic API
temperature=llm_config.temperature,
stream=stream,
)
@@ -416,7 +415,7 @@ def create(
tool_choice=tool_call,
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
max_tokens=1024, # TODO make dynamic
max_tokens=llm_config.max_tokens,
),
)

View File

@@ -94,7 +94,6 @@ def build_openai_chat_completions_request(
functions: Optional[list],
function_call: Optional[str],
use_tool_naming: bool,
max_tokens: Optional[int],
) -> ChatCompletionRequest:
if functions and llm_config.put_inner_thoughts_in_kwargs:
# Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
@@ -131,7 +130,7 @@ def build_openai_chat_completions_request(
tools=[Tool(type="function", function=f) for f in functions] if functions else None,
tool_choice=tool_choice,
user=str(user_id),
max_completion_tokens=max_tokens,
max_completion_tokens=llm_config.max_tokens,
temperature=llm_config.temperature,
)
else:
@@ -141,7 +140,7 @@ def build_openai_chat_completions_request(
functions=functions,
function_call=function_call,
user=str(user_id),
max_completion_tokens=max_tokens,
max_completion_tokens=llm_config.max_tokens,
temperature=llm_config.temperature,
)
# https://platform.openai.com/docs/guides/text-generation/json-mode

View File

@@ -15,6 +15,7 @@ class LLMConfig(BaseModel):
context_window (int): The context window size for the model.
put_inner_thoughts_in_kwargs (bool): Puts `inner_thoughts` as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.
temperature (float): The temperature to use when generating text with the model. A higher temperature will result in more random text.
max_tokens (int): The maximum number of tokens to generate.
"""
# TODO: 🤮 don't default to a vendor! bug city!
@@ -51,6 +52,10 @@ class LLMConfig(BaseModel):
0.7,
description="The temperature to use when generating text with the model. A higher temperature will result in more random text.",
)
max_tokens: Optional[int] = Field(
1024,
description="The maximum number of tokens to generate. If not set, the model will use its default value.",
)
# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())