fix: make togetherai nebius xai etc usable via the openaiprovider (#1981)

Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Kevin Lin <kl2806@columbia.edu>
This commit is contained in:
Charles Packer
2025-05-09 10:50:55 -07:00
committed by GitHub
parent 782971f0dc
commit 8bb194541e
9 changed files with 259 additions and 70 deletions

View File

@@ -215,6 +215,9 @@ def create(
chat_completion_request=data,
stream_interface=stream_interface,
name=name,
# NOTE: should be set to True for OpenAI proxies that emit the `reasoning_content` field
# (e.g. DeepSeek or LM Studio); left False here because vanilla OpenAI does not emit it
expect_reasoning_content=False,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
@@ -272,6 +275,9 @@ def create(
chat_completion_request=data,
stream_interface=stream_interface,
name=name,
# TODO turn on to support reasoning content from xAI reasoners:
# https://docs.x.ai/docs/guides/reasoning#reasoning
expect_reasoning_content=False,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
@@ -486,7 +492,10 @@ def create(
if stream:
raise NotImplementedError(f"Streaming not yet implemented for TogetherAI (via the /completions endpoint).")
if model_settings.together_api_key is None and llm_config.model_endpoint == "https://api.together.ai/v1/completions":
if model_settings.together_api_key is None and (
llm_config.model_endpoint == "https://api.together.ai/v1/completions"
or llm_config.model_endpoint == "https://api.together.xyz/v1/completions"
):
raise LettaConfigurationError(message="TogetherAI key is missing from letta config file", missing_fields=["together_api_key"])
return get_chat_completion(
@@ -560,6 +569,8 @@ def create(
chat_completion_request=data,
stream_interface=stream_interface,
name=name,
# TODO should we toggle for R1 vs V3?
expect_reasoning_content=True,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False

View File

@@ -8,7 +8,13 @@ from letta.constants import LETTA_MODEL_ENDPOINT
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.helpers.datetime_helpers import timestamp_to_datetime
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
from letta.llm_api.openai_client import accepts_developer_role, supports_parallel_tool_calling, supports_temperature_param
from letta.llm_api.openai_client import (
accepts_developer_role,
requires_auto_tool_choice,
supports_parallel_tool_calling,
supports_structured_output,
supports_temperature_param,
)
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
@@ -50,9 +56,7 @@ def openai_check_valid_api_key(base_url: str, api_key: Union[str, None]) -> None
raise ValueError("No API key provided")
def openai_get_model_list(
url: str, api_key: Optional[str] = None, fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
) -> dict:
def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool = False, extra_params: Optional[dict] = None) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
from letta.utils import printd
@@ -154,7 +158,10 @@ def build_openai_chat_completions_request(
elif function_call not in ["none", "auto", "required"]:
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=function_call))
else:
tool_choice = function_call
if requires_auto_tool_choice(llm_config):
tool_choice = "auto"
else:
tool_choice = function_call
data = ChatCompletionRequest(
model=model,
messages=openai_message_list,
@@ -197,12 +204,13 @@ def build_openai_chat_completions_request(
if use_structured_output and data.tools is not None and len(data.tools) > 0:
# Convert to structured output style (which has 'strict' and no optionals)
for tool in data.tools:
try:
# tool["function"] = convert_to_structured_output(tool["function"])
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
if supports_structured_output(llm_config):
try:
# tool["function"] = convert_to_structured_output(tool["function"])
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
return data
@@ -221,7 +229,7 @@ def openai_chat_completions_process_stream(
expect_reasoning_content: bool = True,
name: Optional[str] = None,
) -> ChatCompletionResponse:
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.
"""Process a streaming completion response, and return a ChatCompletionResponse at the end.
To "stream" the response in Letta, we want to call a streaming-compatible interface function
on the chunks received from the OpenAI-compatible server POST SSE response.
@@ -293,6 +301,9 @@ def openai_chat_completions_process_stream(
url=url, api_key=api_key, chat_completion_request=chat_completion_request
):
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
if chat_completion_chunk.choices is None or len(chat_completion_chunk.choices) == 0:
warnings.warn(f"No choices in chunk: {chat_completion_chunk}")
continue
# NOTE: this assumes that the tool call ID will only appear in one of the chunks during the stream
if override_tool_call_id:
@@ -429,6 +440,9 @@ def openai_chat_completions_process_stream(
except Exception as e:
if stream_interface:
stream_interface.stream_end()
import traceback
traceback.print_exc()
logger.error(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
raise e
finally:
@@ -463,14 +477,27 @@ def openai_chat_completions_request_stream(
url: str,
api_key: str,
chat_completion_request: ChatCompletionRequest,
fix_url: bool = False,
) -> Generator[ChatCompletionChunkResponse, None, None]:
# In some cases we may want to double-check the URL and do basic correction, eg:
# In Letta config the address for vLLM is w/o a /v1 suffix for simplicity
# However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit
if fix_url:
if not url.endswith("/v1"):
url = smart_urljoin(url, "v1")
data = prepare_openai_payload(chat_completion_request)
data["stream"] = True
client = OpenAI(api_key=api_key, base_url=url, max_retries=0)
stream = client.chat.completions.create(**data)
for chunk in stream:
# TODO: Use the native OpenAI objects here?
yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True))
try:
stream = client.chat.completions.create(**data)
for chunk in stream:
# TODO: Use the native OpenAI objects here?
yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True))
except Exception as e:
print(f"Error request stream from /v1/chat/completions, url={url}, data={data}:\n{e}")
raise e
def openai_chat_completions_request(

View File

@@ -75,6 +75,35 @@ def supports_parallel_tool_calling(model: str) -> bool:
return True
# TODO move into LLMConfig as a field?
def supports_structured_output(llm_config: LLMConfig) -> bool:
"""Certain providers don't support structured output."""
# FIXME pretty hacky - turn off for providers we know users will use,
# but also don't support structured output
if "nebius.com" in llm_config.model_endpoint:
return False
else:
return True
# TODO move into LLMConfig as a field?
def requires_auto_tool_choice(llm_config: LLMConfig) -> bool:
"""Certain providers require the tool choice to be set to 'auto'."""
if "nebius.com" in llm_config.model_endpoint:
return True
# proxy also has this issue (FIXME check)
elif llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
return True
# same with vLLM (FIXME check)
elif llm_config.handle and "vllm" in llm_config.handle:
return True
else:
# will use "required" instead of "auto"
return False
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
api_key = None
@@ -136,7 +165,7 @@ class OpenAIClient(LLMClientBase):
# TODO(matt) move into LLMConfig
# TODO: This vllm checking is very brittle and is a patch at most
tool_choice = None
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle):
if requires_auto_tool_choice(llm_config):
tool_choice = "auto" # TODO change to "required" once proxy supports it
elif tools:
# only set if tools is non-Null
@@ -171,11 +200,12 @@ class OpenAIClient(LLMClientBase):
if data.tools is not None and len(data.tools) > 0:
# Convert to structured output style (which has 'strict' and no optionals)
for tool in data.tools:
try:
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
if supports_structured_output(llm_config):
try:
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
return data.model_dump(exclude_unset=True)

View File

@@ -24,7 +24,6 @@ class LLMConfig(BaseModel):
max_tokens (int): The maximum number of tokens to generate.
"""
# TODO: 🤮 don't default to a vendor! bug city!
model: str = Field(..., description="LLM model name. ")
model_endpoint_type: Literal[
"openai",

View File

@@ -1,5 +1,5 @@
import datetime
from typing import Dict, List, Literal, Optional, Union
from typing import List, Literal, Optional, Union
from pydantic import BaseModel
@@ -27,6 +27,7 @@ class LogProbToken(BaseModel):
bytes: Optional[List[int]]
# Legacy?
class MessageContentLogProb(BaseModel):
token: str
logprob: float
@@ -34,6 +35,25 @@ class MessageContentLogProb(BaseModel):
top_logprobs: Optional[List[LogProbToken]]
class TopLogprob(BaseModel):
token: str
bytes: Optional[List[int]] = None
logprob: float
class ChatCompletionTokenLogprob(BaseModel):
token: str
bytes: Optional[List[int]] = None
logprob: float
top_logprobs: List[TopLogprob]
class ChoiceLogprobs(BaseModel):
content: Optional[List[ChatCompletionTokenLogprob]] = None
refusal: Optional[List[ChatCompletionTokenLogprob]] = None
class Message(BaseModel):
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
@@ -49,7 +69,7 @@ class Choice(BaseModel):
finish_reason: str
index: int
message: Message
logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None
logprobs: Optional[ChoiceLogprobs] = None
seed: Optional[int] = None # found in TogetherAI
@@ -134,7 +154,7 @@ class ChatCompletionResponse(BaseModel):
class FunctionCallDelta(BaseModel):
# arguments: Optional[str] = None
name: Optional[str] = None
arguments: str
arguments: Optional[str] = None
# name: str
@@ -179,7 +199,7 @@ class ChunkChoice(BaseModel):
finish_reason: Optional[str] = None # NOTE: when streaming will be null
index: int
delta: MessageDelta
logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None
logprobs: Optional[ChoiceLogprobs] = None
class ChatCompletionChunkResponse(BaseModel):

View File

@@ -4,7 +4,7 @@ from typing import List, Literal, Optional
from pydantic import BaseModel, Field, model_validator
from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
@@ -57,7 +57,7 @@ class Provider(ProviderBase):
"""String representation of the provider for display purposes"""
raise NotImplementedError
def get_handle(self, model_name: str, is_embedding: bool = False) -> str:
def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str:
"""
Get the handle for a model, with support for custom overrides.
@@ -68,11 +68,13 @@ class Provider(ProviderBase):
Returns:
str: The handle for the model.
"""
overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
if self.name in overrides and model_name in overrides[self.name]:
model_name = overrides[self.name][model_name]
base_name = base_name if base_name else self.name
return f"{self.name}/{model_name}"
overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
if base_name in overrides and model_name in overrides[base_name]:
model_name = overrides[base_name][model_name]
return f"{base_name}/{model_name}"
def cast_to_subtype(self):
match (self.provider_type):
@@ -162,21 +164,34 @@ class OpenAIProvider(Provider):
openai_check_valid_api_key(self.base_url, self.api_key)
def list_llm_models(self) -> List[LLMConfig]:
def _get_models(self) -> List[dict]:
from letta.llm_api.openai import openai_get_model_list
# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
# See: https://openrouter.ai/docs/requests
extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
response = openai_get_model_list(self.base_url, api_key=self.api_key, extra_params=extra_params)
# TogetherAI's response is missing the 'data' field
# assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
# Similar to Nebius
extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
response = openai_get_model_list(
self.base_url,
api_key=self.api_key,
extra_params=extra_params,
# fix_url=True, # NOTE: make sure together ends with /v1
)
if "data" in response:
data = response["data"]
else:
# TogetherAI's response is missing the 'data' field
data = response
return data
def list_llm_models(self) -> List[LLMConfig]:
data = self._get_models()
configs = []
for model in data:
assert "id" in model, f"OpenAI model missing 'id' field: {model}"
@@ -192,8 +207,8 @@ class OpenAIProvider(Provider):
continue
# TogetherAI includes the type, which we can use to filter out embedding models
if self.base_url == "https://api.together.ai/v1":
if "type" in model and model["type"] != "chat":
if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url:
if "type" in model and model["type"] not in ["chat", "language"]:
continue
# for TogetherAI, we need to skip the models that don't support JSON mode / function calling
@@ -211,11 +226,25 @@ class OpenAIProvider(Provider):
continue
if model["config"]["chat_template"] is None:
continue
if model["config"]["chat_template"] is not None and "tools" not in model["config"]["chat_template"]:
# NOTE: this is a hack to filter out models that don't support tool calling
continue
if "tools" not in model["config"]["chat_template"]:
continue
# if "config" in data and "chat_template" in data["config"] and "tools" not in data["config"]["chat_template"]:
# continue
if "nebius.com" in self.base_url:
# Nebius includes the type, which we can use to filter for text models
try:
model_type = model["architecture"]["modality"]
if model_type not in ["text->text", "text+image->text"]:
# print(f"Skipping model w/ modality {model_type}:\n{model}")
continue
except KeyError:
print(f"Couldn't access architecture type field, skipping model:\n{model}")
continue
# for openai, filter models
if self.base_url == "https://api.openai.com/v1":
allowed_types = ["gpt-4", "o1", "o3"]
@@ -235,13 +264,19 @@ class OpenAIProvider(Provider):
if skip:
continue
# set the handle to openai-proxy if the base URL isn't OpenAI
if self.base_url != "https://api.openai.com/v1":
handle = self.get_handle(model_name, base_name="openai-proxy")
else:
handle = self.get_handle(model_name)
configs.append(
LLMConfig(
model=model_name,
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
handle=handle,
provider_name=self.name,
provider_category=self.provider_category,
)
@@ -256,33 +291,87 @@ class OpenAIProvider(Provider):
def list_embedding_models(self) -> List[EmbeddingConfig]:
# TODO: actually automatically list models
return [
EmbeddingConfig(
embedding_model="text-embedding-ada-002",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=1536,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
),
EmbeddingConfig(
embedding_model="text-embedding-3-small",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=2000,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-3-small", is_embedding=True),
),
EmbeddingConfig(
embedding_model="text-embedding-3-large",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=2000,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-3-large", is_embedding=True),
),
]
if self.base_url == "https://api.openai.com/v1":
# TODO: actually automatically list models for OpenAI
return [
EmbeddingConfig(
embedding_model="text-embedding-ada-002",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=1536,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
),
EmbeddingConfig(
embedding_model="text-embedding-3-small",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=2000,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-3-small", is_embedding=True),
),
EmbeddingConfig(
embedding_model="text-embedding-3-large",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=2000,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-3-large", is_embedding=True),
),
]
else:
# Actually attempt to list
data = self._get_models()
configs = []
for model in data:
assert "id" in model, f"Model missing 'id' field: {model}"
model_name = model["id"]
if "context_length" in model:
# Context length is returned in Nebius as "context_length"
context_window_size = model["context_length"]
else:
context_window_size = self.get_model_context_window_size(model_name)
# We need the context length for embeddings too
if not context_window_size:
continue
if "nebius.com" in self.base_url:
# Nebius includes the type, which we can use to filter for embedding models
try:
model_type = model["architecture"]["modality"]
if model_type not in ["text->embedding"]:
# print(f"Skipping model w/ modality {model_type}:\n{model}")
continue
except KeyError:
print(f"Couldn't access architecture type field, skipping model:\n{model}")
continue
elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
# TogetherAI includes the type, which we can use to filter for embedding models
if "type" in model and model["type"] not in ["embedding"]:
# print(f"Skipping model w/ modality {model_type}:\n{model}")
continue
else:
# For other providers we should skip by default, since we don't want to assume embeddings are supported
continue
configs.append(
EmbeddingConfig(
embedding_model=model_name,
embedding_endpoint_type=self.provider_type,
embedding_endpoint=self.base_url,
embedding_dim=context_window_size,
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=self.get_handle(model_name, is_embedding=True),
)
)
return configs
def get_model_context_window_size(self, model_name: str):
if model_name in LLM_MAX_TOKENS:

View File

@@ -482,6 +482,10 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
data: {"function_return": "None", "status": "success", "date": "2024-02-29T06:07:50.847262+00:00"}
"""
if not chunk.choices or len(chunk.choices) == 0:
warnings.warn(f"No choices in chunk: {chunk}")
return None
choice = chunk.choices[0]
message_delta = choice.delta
otid = Message.generate_otid_from_id(message_id, message_index)

View File

@@ -1219,6 +1219,9 @@ class SyncServer(Server):
try:
llm_models.extend(provider.list_llm_models())
except Exception as e:
import traceback
traceback.print_exc()
warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
llm_models.extend(self.get_local_llm_configs())

View File

@@ -2,7 +2,7 @@ import os
from pathlib import Path
from typing import Optional
from pydantic import Field
from pydantic import AliasChoices, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from letta.local_llm.constants import DEFAULT_WRAPPER_NAME
@@ -70,7 +70,13 @@ class ModelSettings(BaseSettings):
# openai
openai_api_key: Optional[str] = None
openai_api_base: str = "https://api.openai.com/v1"
openai_api_base: str = Field(
default="https://api.openai.com/v1",
# NOTE: We previously used OPENAI_API_BASE, but this was deprecated in favor of OPENAI_BASE_URL
# preferred first, fallback second
# env=["OPENAI_BASE_URL", "OPENAI_API_BASE"], # pydantic(-settings) v1 style
validation_alias=AliasChoices("OPENAI_BASE_URL", "OPENAI_API_BASE"), # pydantic-settings v2
)
# deepseek
deepseek_api_key: Optional[str] = None