fix: add max tokens (#795)

Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
Kevin Lin
2025-02-10 20:28:03 -08:00
committed by GitHub
parent 308e6ce215
commit e858871ccb
3 changed files with 12 additions and 9 deletions

View File

@@ -111,7 +111,6 @@ def create(
# streaming?
stream: bool = False,
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
max_tokens: Optional[int] = None,
model_settings: Optional[dict] = None, # TODO: eventually pass from server
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
@@ -157,7 +156,7 @@ def create(
else:
function_call = "required"
data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming)
if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
@@ -212,7 +211,7 @@ def create(
# For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
llm_config.model_endpoint = model_settings.azure_base_url
chat_completion_request = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens
llm_config, messages, user_id, functions, function_call, use_tool_naming
)
response = azure_openai_chat_completions_request(
@@ -248,7 +247,7 @@ def create(
data=dict(
contents=[m.to_google_ai_dict() for m in messages],
tools=tools,
generation_config={"temperature": llm_config.temperature},
generation_config={"temperature": llm_config.temperature, "max_output_tokens": llm_config.max_tokens},
),
inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
)
@@ -268,7 +267,7 @@ def create(
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=([{"type": "function", "function": f} for f in functions] if functions else None),
tool_choice=tool_call,
max_tokens=1024, # TODO make dynamic
max_tokens=llm_config.max_tokens, # Note: max_tokens is required for Anthropic API
temperature=llm_config.temperature,
stream=stream,
)
@@ -416,7 +415,7 @@ def create(
tool_choice=tool_call,
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
max_tokens=1024, # TODO make dynamic
max_tokens=llm_config.max_tokens,
),
)