fix: make togetherai nebius xai etc usable via the openaiprovider (#1981)
Co-authored-by: Kevin Lin <klin5061@gmail.com> Co-authored-by: Kevin Lin <kl2806@columbia.edu>
This commit is contained in:
@@ -215,6 +215,9 @@ def create(
|
||||
chat_completion_request=data,
|
||||
stream_interface=stream_interface,
|
||||
name=name,
|
||||
# NOTE: needs to be true for OpenAI proxies that use the `reasoning_content` field
|
||||
# For example, DeepSeek, or LM Studio
|
||||
expect_reasoning_content=False,
|
||||
)
|
||||
else: # Client did not request token streaming (expect a blocking backend response)
|
||||
data.stream = False
|
||||
@@ -272,6 +275,9 @@ def create(
|
||||
chat_completion_request=data,
|
||||
stream_interface=stream_interface,
|
||||
name=name,
|
||||
# TODO turn on to support reasoning content from xAI reasoners:
|
||||
# https://docs.x.ai/docs/guides/reasoning#reasoning
|
||||
expect_reasoning_content=False,
|
||||
)
|
||||
else: # Client did not request token streaming (expect a blocking backend response)
|
||||
data.stream = False
|
||||
@@ -486,7 +492,10 @@ def create(
|
||||
if stream:
|
||||
raise NotImplementedError(f"Streaming not yet implemented for TogetherAI (via the /completions endpoint).")
|
||||
|
||||
if model_settings.together_api_key is None and llm_config.model_endpoint == "https://api.together.ai/v1/completions":
|
||||
if model_settings.together_api_key is None and (
|
||||
llm_config.model_endpoint == "https://api.together.ai/v1/completions"
|
||||
or llm_config.model_endpoint == "https://api.together.xyz/v1/completions"
|
||||
):
|
||||
raise LettaConfigurationError(message="TogetherAI key is missing from letta config file", missing_fields=["together_api_key"])
|
||||
|
||||
return get_chat_completion(
|
||||
@@ -560,6 +569,8 @@ def create(
|
||||
chat_completion_request=data,
|
||||
stream_interface=stream_interface,
|
||||
name=name,
|
||||
# TODO should we toggle for R1 vs V3?
|
||||
expect_reasoning_content=True,
|
||||
)
|
||||
else: # Client did not request token streaming (expect a blocking backend response)
|
||||
data.stream = False
|
||||
|
||||
@@ -8,7 +8,13 @@ from letta.constants import LETTA_MODEL_ENDPOINT
|
||||
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
|
||||
from letta.helpers.datetime_helpers import timestamp_to_datetime
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
|
||||
from letta.llm_api.openai_client import accepts_developer_role, supports_parallel_tool_calling, supports_temperature_param
|
||||
from letta.llm_api.openai_client import (
|
||||
accepts_developer_role,
|
||||
requires_auto_tool_choice,
|
||||
supports_parallel_tool_calling,
|
||||
supports_structured_output,
|
||||
supports_temperature_param,
|
||||
)
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
@@ -50,9 +56,7 @@ def openai_check_valid_api_key(base_url: str, api_key: Union[str, None]) -> None
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
|
||||
def openai_get_model_list(
|
||||
url: str, api_key: Optional[str] = None, fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
|
||||
) -> dict:
|
||||
def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool = False, extra_params: Optional[dict] = None) -> dict:
|
||||
"""https://platform.openai.com/docs/api-reference/models/list"""
|
||||
from letta.utils import printd
|
||||
|
||||
@@ -154,7 +158,10 @@ def build_openai_chat_completions_request(
|
||||
elif function_call not in ["none", "auto", "required"]:
|
||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=function_call))
|
||||
else:
|
||||
tool_choice = function_call
|
||||
if requires_auto_tool_choice(llm_config):
|
||||
tool_choice = "auto"
|
||||
else:
|
||||
tool_choice = function_call
|
||||
data = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=openai_message_list,
|
||||
@@ -197,12 +204,13 @@ def build_openai_chat_completions_request(
|
||||
if use_structured_output and data.tools is not None and len(data.tools) > 0:
|
||||
# Convert to structured output style (which has 'strict' and no optionals)
|
||||
for tool in data.tools:
|
||||
try:
|
||||
# tool["function"] = convert_to_structured_output(tool["function"])
|
||||
structured_output_version = convert_to_structured_output(tool.function.model_dump())
|
||||
tool.function = FunctionSchema(**structured_output_version)
|
||||
except ValueError as e:
|
||||
warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
if supports_structured_output(llm_config):
|
||||
try:
|
||||
# tool["function"] = convert_to_structured_output(tool["function"])
|
||||
structured_output_version = convert_to_structured_output(tool.function.model_dump())
|
||||
tool.function = FunctionSchema(**structured_output_version)
|
||||
except ValueError as e:
|
||||
warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
return data
|
||||
|
||||
|
||||
@@ -221,7 +229,7 @@ def openai_chat_completions_process_stream(
|
||||
expect_reasoning_content: bool = True,
|
||||
name: Optional[str] = None,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.
|
||||
"""Process a streaming completion response, and return a ChatCompletionResponse at the end.
|
||||
|
||||
To "stream" the response in Letta, we want to call a streaming-compatible interface function
|
||||
on the chunks received from the OpenAI-compatible server POST SSE response.
|
||||
@@ -293,6 +301,9 @@ def openai_chat_completions_process_stream(
|
||||
url=url, api_key=api_key, chat_completion_request=chat_completion_request
|
||||
):
|
||||
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
|
||||
if chat_completion_chunk.choices is None or len(chat_completion_chunk.choices) == 0:
|
||||
warnings.warn(f"No choices in chunk: {chat_completion_chunk}")
|
||||
continue
|
||||
|
||||
# NOTE: this assumes that the tool call ID will only appear in one of the chunks during the stream
|
||||
if override_tool_call_id:
|
||||
@@ -429,6 +440,9 @@ def openai_chat_completions_process_stream(
|
||||
except Exception as e:
|
||||
if stream_interface:
|
||||
stream_interface.stream_end()
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
|
||||
raise e
|
||||
finally:
|
||||
@@ -463,14 +477,27 @@ def openai_chat_completions_request_stream(
|
||||
url: str,
|
||||
api_key: str,
|
||||
chat_completion_request: ChatCompletionRequest,
|
||||
fix_url: bool = False,
|
||||
) -> Generator[ChatCompletionChunkResponse, None, None]:
|
||||
|
||||
# In some cases we may want to double-check the URL and do basic correction, eg:
|
||||
# In Letta config the address for vLLM is w/o a /v1 suffix for simplicity
|
||||
# However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit
|
||||
if fix_url:
|
||||
if not url.endswith("/v1"):
|
||||
url = smart_urljoin(url, "v1")
|
||||
|
||||
data = prepare_openai_payload(chat_completion_request)
|
||||
data["stream"] = True
|
||||
client = OpenAI(api_key=api_key, base_url=url, max_retries=0)
|
||||
stream = client.chat.completions.create(**data)
|
||||
for chunk in stream:
|
||||
# TODO: Use the native OpenAI objects here?
|
||||
yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True))
|
||||
try:
|
||||
stream = client.chat.completions.create(**data)
|
||||
for chunk in stream:
|
||||
# TODO: Use the native OpenAI objects here?
|
||||
yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True))
|
||||
except Exception as e:
|
||||
print(f"Error request stream from /v1/chat/completions, url={url}, data={data}:\n{e}")
|
||||
raise e
|
||||
|
||||
|
||||
def openai_chat_completions_request(
|
||||
|
||||
@@ -75,6 +75,35 @@ def supports_parallel_tool_calling(model: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
# TODO move into LLMConfig as a field?
|
||||
def supports_structured_output(llm_config: LLMConfig) -> bool:
    """Return whether the provider behind ``llm_config`` supports structured output.

    Args:
        llm_config: Model configuration; only ``model_endpoint`` is inspected.

    Returns:
        ``False`` when the endpoint URL contains ``"nebius.com"``, ``True``
        otherwise (including when no endpoint is configured).
    """
    # FIXME pretty hacky - turn off for providers we know users will use,
    # but also don't support structured output
    # A missing endpoint is treated as "no known restriction" rather than
    # letting the substring test raise TypeError on None.
    endpoint = llm_config.model_endpoint or ""
    return "nebius.com" not in endpoint
|
||||
|
||||
|
||||
# TODO move into LLMConfig as a field?
|
||||
def requires_auto_tool_choice(llm_config: LLMConfig) -> bool:
    """Return whether the provider requires ``tool_choice`` to be ``"auto"``.

    Args:
        llm_config: Model configuration; ``model_endpoint`` and ``handle`` are inspected.

    Returns:
        ``True`` for Nebius endpoints, for the Letta model endpoint, and for
        handles containing ``"vllm"``; ``False`` otherwise.
    """
    # Guard against an unset endpoint so the substring test cannot raise
    # TypeError on None (handle is already guarded the same way below).
    endpoint = llm_config.model_endpoint or ""
    if "nebius.com" in endpoint:
        return True

    # proxy also has this issue (FIXME check)
    if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
        return True

    # same with vLLM (FIXME check)
    if llm_config.handle and "vllm" in llm_config.handle:
        return True

    # will use "required" instead of "auto"
    return False
|
||||
|
||||
|
||||
class OpenAIClient(LLMClientBase):
|
||||
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
|
||||
api_key = None
|
||||
@@ -136,7 +165,7 @@ class OpenAIClient(LLMClientBase):
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
tool_choice = None
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle):
|
||||
if requires_auto_tool_choice(llm_config):
|
||||
tool_choice = "auto" # TODO change to "required" once proxy supports it
|
||||
elif tools:
|
||||
# only set if tools is non-Null
|
||||
@@ -171,11 +200,12 @@ class OpenAIClient(LLMClientBase):
|
||||
if data.tools is not None and len(data.tools) > 0:
|
||||
# Convert to structured output style (which has 'strict' and no optionals)
|
||||
for tool in data.tools:
|
||||
try:
|
||||
structured_output_version = convert_to_structured_output(tool.function.model_dump())
|
||||
tool.function = FunctionSchema(**structured_output_version)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
if supports_structured_output(llm_config):
|
||||
try:
|
||||
structured_output_version = convert_to_structured_output(tool.function.model_dump())
|
||||
tool.function = FunctionSchema(**structured_output_version)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
|
||||
return data.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ class LLMConfig(BaseModel):
|
||||
max_tokens (int): The maximum number of tokens to generate.
|
||||
"""
|
||||
|
||||
# TODO: 🤮 don't default to a vendor! bug city!
|
||||
model: str = Field(..., description="LLM model name. ")
|
||||
model_endpoint_type: Literal[
|
||||
"openai",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import datetime
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -27,6 +27,7 @@ class LogProbToken(BaseModel):
|
||||
bytes: Optional[List[int]]
|
||||
|
||||
|
||||
# Legacy?
|
||||
class MessageContentLogProb(BaseModel):
|
||||
token: str
|
||||
logprob: float
|
||||
@@ -34,6 +35,25 @@ class MessageContentLogProb(BaseModel):
|
||||
top_logprobs: Optional[List[LogProbToken]]
|
||||
|
||||
|
||||
class TopLogprob(BaseModel):
    """One candidate token with its log probability.

    Used as the element type of ``ChatCompletionTokenLogprob.top_logprobs``.
    """

    # Text of the candidate token.
    token: str
    # Optional list of integers for the token; presumably its byte values - confirm against the provider schema.
    bytes: Optional[List[int]] = None
    # Log probability reported for this token.
    logprob: float
|
||||
|
||||
|
||||
class ChatCompletionTokenLogprob(BaseModel):
    """Log-probability record for a single token, plus its alternative candidates.

    Used as the element type of ``ChoiceLogprobs.content`` and ``ChoiceLogprobs.refusal``.
    """

    # Text of the token.
    token: str
    # Optional list of integers for the token; presumably its byte values - confirm against the provider schema.
    bytes: Optional[List[int]] = None
    # Log probability reported for this token.
    logprob: float
    # Required list of alternative candidates (no default, so it must be present in the payload).
    top_logprobs: List[TopLogprob]
|
||||
|
||||
|
||||
class ChoiceLogprobs(BaseModel):
    """Container for per-token log probabilities attached to a choice.

    Both fields are optional and default to ``None``.
    """

    # Per-token log probabilities for the message content, when provided.
    content: Optional[List[ChatCompletionTokenLogprob]] = None

    # Per-token log probabilities for a refusal, when provided.
    refusal: Optional[List[ChatCompletionTokenLogprob]] = None
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = None
|
||||
@@ -49,7 +69,7 @@ class Choice(BaseModel):
|
||||
finish_reason: str
|
||||
index: int
|
||||
message: Message
|
||||
logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None
|
||||
logprobs: Optional[ChoiceLogprobs] = None
|
||||
seed: Optional[int] = None # found in TogetherAI
|
||||
|
||||
|
||||
@@ -134,7 +154,7 @@ class ChatCompletionResponse(BaseModel):
|
||||
class FunctionCallDelta(BaseModel):
|
||||
# arguments: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
arguments: str
|
||||
arguments: Optional[str] = None
|
||||
# name: str
|
||||
|
||||
|
||||
@@ -179,7 +199,7 @@ class ChunkChoice(BaseModel):
|
||||
finish_reason: Optional[str] = None # NOTE: when streaming will be null
|
||||
index: int
|
||||
delta: MessageDelta
|
||||
logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None
|
||||
logprobs: Optional[ChoiceLogprobs] = None
|
||||
|
||||
|
||||
class ChatCompletionChunkResponse(BaseModel):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
|
||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -57,7 +57,7 @@ class Provider(ProviderBase):
|
||||
"""String representation of the provider for display purposes"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_handle(self, model_name: str, is_embedding: bool = False) -> str:
|
||||
def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str:
|
||||
"""
|
||||
Get the handle for a model, with support for custom overrides.
|
||||
|
||||
@@ -68,11 +68,13 @@ class Provider(ProviderBase):
|
||||
Returns:
|
||||
str: The handle for the model.
|
||||
"""
|
||||
overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
|
||||
if self.name in overrides and model_name in overrides[self.name]:
|
||||
model_name = overrides[self.name][model_name]
|
||||
base_name = base_name if base_name else self.name
|
||||
|
||||
return f"{self.name}/{model_name}"
|
||||
overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
|
||||
if base_name in overrides and model_name in overrides[base_name]:
|
||||
model_name = overrides[base_name][model_name]
|
||||
|
||||
return f"{base_name}/{model_name}"
|
||||
|
||||
def cast_to_subtype(self):
|
||||
match (self.provider_type):
|
||||
@@ -162,21 +164,34 @@ class OpenAIProvider(Provider):
|
||||
|
||||
openai_check_valid_api_key(self.base_url, self.api_key)
|
||||
|
||||
def _get_models(self) -> List[dict]:
    """Fetch the raw model listing from the provider's OpenAI-compatible endpoint.

    Returns:
        The value of the response's ``"data"`` field when present, otherwise
        the response itself (TogetherAI's response is missing ``"data"``).
    """
    from letta.llm_api.openai import openai_get_model_list

    # Provider-specific query parameters. These must be mutually exclusive:
    # assigning them one after the other would let the later assignment
    # overwrite the earlier one with None.
    if "openrouter.ai" in self.base_url:
        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"}
    elif "nebius.com" in self.base_url:
        # Similar to Nebius
        extra_params = {"verbose": True}
    else:
        extra_params = None

    response = openai_get_model_list(
        self.base_url,
        api_key=self.api_key,
        extra_params=extra_params,
        # fix_url=True,  # NOTE: make sure together ends with /v1
    )

    if "data" in response:
        data = response["data"]
    else:
        # TogetherAI's response is missing the 'data' field
        data = response

    return data
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
data = self._get_models()
|
||||
|
||||
configs = []
|
||||
for model in data:
|
||||
assert "id" in model, f"OpenAI model missing 'id' field: {model}"
|
||||
@@ -192,8 +207,8 @@ class OpenAIProvider(Provider):
|
||||
continue
|
||||
|
||||
# TogetherAI includes the type, which we can use to filter out embedding models
|
||||
if self.base_url == "https://api.together.ai/v1":
|
||||
if "type" in model and model["type"] != "chat":
|
||||
if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url:
|
||||
if "type" in model and model["type"] not in ["chat", "language"]:
|
||||
continue
|
||||
|
||||
# for TogetherAI, we need to skip the models that don't support JSON mode / function calling
|
||||
@@ -211,11 +226,25 @@ class OpenAIProvider(Provider):
|
||||
continue
|
||||
if model["config"]["chat_template"] is None:
|
||||
continue
|
||||
if model["config"]["chat_template"] is not None and "tools" not in model["config"]["chat_template"]:
|
||||
# NOTE: this is a hack to filter out models that don't support tool calling
|
||||
continue
|
||||
if "tools" not in model["config"]["chat_template"]:
|
||||
continue
|
||||
# if "config" in data and "chat_template" in data["config"] and "tools" not in data["config"]["chat_template"]:
|
||||
# continue
|
||||
|
||||
if "nebius.com" in self.base_url:
|
||||
# Nebius includes the type, which we can use to filter for text models
|
||||
try:
|
||||
model_type = model["architecture"]["modality"]
|
||||
if model_type not in ["text->text", "text+image->text"]:
|
||||
# print(f"Skipping model w/ modality {model_type}:\n{model}")
|
||||
continue
|
||||
except KeyError:
|
||||
print(f"Couldn't access architecture type field, skipping model:\n{model}")
|
||||
continue
|
||||
|
||||
# for openai, filter models
|
||||
if self.base_url == "https://api.openai.com/v1":
|
||||
allowed_types = ["gpt-4", "o1", "o3"]
|
||||
@@ -235,13 +264,19 @@ class OpenAIProvider(Provider):
|
||||
if skip:
|
||||
continue
|
||||
|
||||
# set the handle to openai-proxy if the base URL isn't OpenAI
|
||||
if self.base_url != "https://api.openai.com/v1":
|
||||
handle = self.get_handle(model_name, base_name="openai-proxy")
|
||||
else:
|
||||
handle = self.get_handle(model_name)
|
||||
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model_name,
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
handle=handle,
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
@@ -256,33 +291,87 @@ class OpenAIProvider(Provider):
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
    """List the embedding models available from this provider.

    For the official OpenAI base URL a hardcoded list of three models is
    returned. For any other base URL the provider's model listing is fetched
    via ``_get_models`` and filtered down to embedding models; only Nebius and
    TogetherAI endpoints are recognised, every other provider yields no
    embedding models.

    Returns:
        A list of ``EmbeddingConfig`` objects (possibly empty).
    """
    if self.base_url == "https://api.openai.com/v1":
        # TODO: actually automatically list models for OpenAI
        return [
            EmbeddingConfig(
                embedding_model="text-embedding-ada-002",
                embedding_endpoint_type="openai",
                embedding_endpoint=self.base_url,
                embedding_dim=1536,
                embedding_chunk_size=300,
                handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
            ),
            EmbeddingConfig(
                embedding_model="text-embedding-3-small",
                embedding_endpoint_type="openai",
                embedding_endpoint=self.base_url,
                embedding_dim=2000,
                embedding_chunk_size=300,
                handle=self.get_handle("text-embedding-3-small", is_embedding=True),
            ),
            EmbeddingConfig(
                embedding_model="text-embedding-3-large",
                embedding_endpoint_type="openai",
                embedding_endpoint=self.base_url,
                embedding_dim=2000,
                embedding_chunk_size=300,
                handle=self.get_handle("text-embedding-3-large", is_embedding=True),
            ),
        ]

    # Actually attempt to list
    data = self._get_models()

    configs = []
    for model in data:
        assert "id" in model, f"Model missing 'id' field: {model}"
        model_name = model["id"]

        if "context_length" in model:
            # Context length is returned in Nebius as "context_length"
            context_window_size = model["context_length"]
        else:
            context_window_size = self.get_model_context_window_size(model_name)

        # We need the context length for embeddings too
        if not context_window_size:
            continue

        if "nebius.com" in self.base_url:
            # Nebius includes the type, which we can use to filter for embedding models
            try:
                model_type = model["architecture"]["modality"]
                if model_type not in ["text->embedding"]:
                    # print(f"Skipping model w/ modality {model_type}:\n{model}")
                    continue
            except KeyError:
                print(f"Couldn't access architecture type field, skipping model:\n{model}")
                continue

        elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
            # TogetherAI includes the type, which we can use to filter for embedding models
            if "type" in model and model["type"] not in ["embedding"]:
                # print(f"Skipping model w/ modality {model_type}:\n{model}")
                continue

        else:
            # For other providers we should skip by default, since we don't want to assume embeddings are supported
            continue

        configs.append(
            EmbeddingConfig(
                embedding_model=model_name,
                embedding_endpoint_type=self.provider_type,
                embedding_endpoint=self.base_url,
                # NOTE(review): embedding_dim is populated from the context length,
                # not from a vector dimension - confirm this is intended.
                embedding_dim=context_window_size,
                embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
                # get_handle expects the model name (a str), not the raw model dict
                handle=self.get_handle(model_name, is_embedding=True),
            )
        )

    return configs
|
||||
|
||||
def get_model_context_window_size(self, model_name: str):
|
||||
if model_name in LLM_MAX_TOKENS:
|
||||
|
||||
@@ -482,6 +482,10 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
|
||||
data: {"function_return": "None", "status": "success", "date": "2024-02-29T06:07:50.847262+00:00"}
|
||||
"""
|
||||
if not chunk.choices or len(chunk.choices) == 0:
|
||||
warnings.warn(f"No choices in chunk: {chunk}")
|
||||
return None
|
||||
|
||||
choice = chunk.choices[0]
|
||||
message_delta = choice.delta
|
||||
otid = Message.generate_otid_from_id(message_id, message_index)
|
||||
|
||||
@@ -1219,6 +1219,9 @@ class SyncServer(Server):
|
||||
try:
|
||||
llm_models.extend(provider.list_llm_models())
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
|
||||
|
||||
llm_models.extend(self.get_local_llm_configs())
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from letta.local_llm.constants import DEFAULT_WRAPPER_NAME
|
||||
@@ -70,7 +70,13 @@ class ModelSettings(BaseSettings):
|
||||
|
||||
# openai
|
||||
openai_api_key: Optional[str] = None
|
||||
openai_api_base: str = "https://api.openai.com/v1"
|
||||
openai_api_base: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
# NOTE: We previously used OPENAI_API_BASE, but this was deprecated in favor of OPENAI_BASE_URL
|
||||
# preferred first, fallback second
|
||||
# env=["OPENAI_BASE_URL", "OPENAI_API_BASE"], # pydantic-settings v2
|
||||
validation_alias=AliasChoices("OPENAI_BASE_URL", "OPENAI_API_BASE"), # pydantic-settings v1
|
||||
)
|
||||
|
||||
# deepseek
|
||||
deepseek_api_key: Optional[str] = None
|
||||
|
||||
Reference in New Issue
Block a user