fix: make togetherai nebius xai etc usable via the openaiprovider (#1981)
Co-authored-by: Kevin Lin <klin5061@gmail.com> Co-authored-by: Kevin Lin <kl2806@columbia.edu>
This commit is contained in:
@@ -75,6 +75,35 @@ def supports_parallel_tool_calling(model: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
# TODO move into LLMConfig as a field?
|
||||
def supports_structured_output(llm_config: LLMConfig) -> bool:
|
||||
"""Certain providers don't support structured output."""
|
||||
|
||||
# FIXME pretty hacky - turn off for providers we know users will use,
|
||||
# but also don't support structured output
|
||||
if "nebius.com" in llm_config.model_endpoint:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
# TODO move into LLMConfig as a field?
|
||||
def requires_auto_tool_choice(llm_config: LLMConfig) -> bool:
|
||||
"""Certain providers require the tool choice to be set to 'auto'."""
|
||||
|
||||
if "nebius.com" in llm_config.model_endpoint:
|
||||
return True
|
||||
# proxy also has this issue (FIXME check)
|
||||
elif llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
|
||||
return True
|
||||
# same with vLLM (FIXME check)
|
||||
elif llm_config.handle and "vllm" in llm_config.handle:
|
||||
return True
|
||||
else:
|
||||
# will use "required" instead of "auto"
|
||||
return False
|
||||
|
||||
|
||||
class OpenAIClient(LLMClientBase):
|
||||
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
|
||||
api_key = None
|
||||
@@ -136,7 +165,7 @@ class OpenAIClient(LLMClientBase):
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
tool_choice = None
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle):
|
||||
if requires_auto_tool_choice(llm_config):
|
||||
tool_choice = "auto" # TODO change to "required" once proxy supports it
|
||||
elif tools:
|
||||
# only set if tools is non-Null
|
||||
@@ -171,11 +200,12 @@ class OpenAIClient(LLMClientBase):
|
||||
if data.tools is not None and len(data.tools) > 0:
|
||||
# Convert to structured output style (which has 'strict' and no optionals)
|
||||
for tool in data.tools:
|
||||
try:
|
||||
structured_output_version = convert_to_structured_output(tool.function.model_dump())
|
||||
tool.function = FunctionSchema(**structured_output_version)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
if supports_structured_output(llm_config):
|
||||
try:
|
||||
structured_output_version = convert_to_structured_output(tool.function.model_dump())
|
||||
tool.function = FunctionSchema(**structured_output_version)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
|
||||
return data.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user