fix: make togetherai nebius xai etc usable via the openaiprovider (#1981)

Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Kevin Lin <kl2806@columbia.edu>
This commit is contained in:
Charles Packer
2025-05-09 10:50:55 -07:00
committed by GitHub
parent 0e2291434e
commit fce28c73e3
9 changed files with 259 additions and 70 deletions

View File

@@ -75,6 +75,35 @@ def supports_parallel_tool_calling(model: str) -> bool:
return True
# TODO move into LLMConfig as a field?
def supports_structured_output(llm_config: LLMConfig) -> bool:
"""Certain providers don't support structured output."""
# FIXME pretty hacky - turn off for providers we know users will use,
# but also don't support structured output
if "nebius.com" in llm_config.model_endpoint:
return False
else:
return True
# TODO move into LLMConfig as a field?
def requires_auto_tool_choice(llm_config: LLMConfig) -> bool:
"""Certain providers require the tool choice to be set to 'auto'."""
if "nebius.com" in llm_config.model_endpoint:
return True
# proxy also has this issue (FIXME check)
elif llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
return True
# same with vLLM (FIXME check)
elif llm_config.handle and "vllm" in llm_config.handle:
return True
else:
# will use "required" instead of "auto"
return False
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
api_key = None
@@ -136,7 +165,7 @@ class OpenAIClient(LLMClientBase):
# TODO(matt) move into LLMConfig
# TODO: This vllm checking is very brittle and is a patch at most
tool_choice = None
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle):
if requires_auto_tool_choice(llm_config):
tool_choice = "auto" # TODO change to "required" once proxy supports it
elif tools:
# only set if tools is non-Null
@@ -171,11 +200,12 @@ class OpenAIClient(LLMClientBase):
if data.tools is not None and len(data.tools) > 0:
# Convert to structured output style (which has 'strict' and no optionals)
for tool in data.tools:
try:
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
if supports_structured_output(llm_config):
try:
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
return data.model_dump(exclude_unset=True)