feat: store handle in configs (#2299)
Co-authored-by: Caren Thomas <caren@caren-mac.local>
@@ -27,6 +27,10 @@ class Provider(BaseModel):
     def provider_tag(self) -> str:
         """String representation of the provider for display purposes"""
         raise NotImplementedError
 
+    def get_handle(self, model_name: str) -> str:
+        return f"{self.name}/{model_name}"
+
+
 class LettaProvider(Provider):
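For context, here is what the new base-class helper does in isolation. This is a minimal sketch, assuming `Provider` is a Pydantic model whose `name` field holds the provider identifier (as the diff context suggests); the example values are illustrative, not taken from this commit:

```python
from pydantic import BaseModel


# Minimal sketch of the new helper; every Provider subclass inherits it,
# so each provider derives handles the same way.
class Provider(BaseModel):
    name: str

    def get_handle(self, model_name: str) -> str:
        # Handle format: "<provider-name>/<model-name>"
        return f"{self.name}/{model_name}"


# Illustrative values only:
assert Provider(name="openai").get_handle("gpt-4") == "openai/gpt-4"
```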
@@ -40,6 +44,7 @@ class LettaProvider(Provider):
                 model_endpoint_type="openai",
                 model_endpoint="https://inference.memgpt.ai",
                 context_window=16384,
+                handle=self.get_handle("letta-free")
             )
         ]
@@ -51,6 +56,7 @@ class LettaProvider(Provider):
                 embedding_endpoint="https://embeddings.memgpt.ai",
                 embedding_dim=1024,
                 embedding_chunk_size=300,
+                handle=self.get_handle("letta-free")
             )
         ]
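Both LettaProvider configs above pass the fixed model name "letta-free", so the stored handle is deterministic. A quick sketch of the resulting value, assuming the provider's `name` is "letta" (an assumption; the field's value is not shown in this diff):

```python
# Sketch only: the Letta provider serves a single hosted model, so its
# LLM and embedding configs share one constant handle.
provider_name = "letta"  # assumed; not visible in the diff
model_name = "letta-free"
handle = f"{provider_name}/{model_name}"
print(handle)  # -> letta/letta-free
```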
@@ -115,7 +121,7 @@ class OpenAIProvider(Provider):
                 # continue
 
             configs.append(
-                LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size)
+                LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name))
             )
 
         # for OpenAI, sort in reverse order
@@ -135,6 +141,7 @@ class OpenAIProvider(Provider):
                 embedding_endpoint="https://api.openai.com/v1",
                 embedding_dim=1536,
                 embedding_chunk_size=300,
+                handle=self.get_handle("text-embedding-ada-002")
             )
         ]
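Unlike the Letta provider's constant handle, OpenAIProvider builds one handle per model returned by the models API. A sketch of that loop's effect, with hypothetical model names standing in for the live listing:

```python
# Hypothetical model names; the real list comes from OpenAI's /v1/models.
provider_name = "openai"
for model_name in ["gpt-4", "gpt-4o-mini"]:
    handle = f"{provider_name}/{model_name}"
    # Each LLMConfig appended above now carries e.g. "openai/gpt-4".
    print(handle)
```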
@@ -163,6 +170,7 @@ class AnthropicProvider(Provider):
                     model_endpoint_type="anthropic",
                     model_endpoint=self.base_url,
                     context_window=model["context_window"],
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -195,6 +203,7 @@ class MistralProvider(Provider):
                     model_endpoint_type="openai",
                     model_endpoint=self.base_url,
                     context_window=model["max_context_length"],
+                    handle=self.get_handle(model["id"])
                 )
             )
@@ -250,6 +259,7 @@ class OllamaProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -325,6 +335,7 @@ class OllamaProvider(OpenAIProvider):
                     embedding_endpoint=self.base_url,
                     embedding_dim=embedding_dim,
                     embedding_chunk_size=300,
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -345,7 +356,7 @@ class GroqProvider(OpenAIProvider):
                 continue
             configs.append(
                 LLMConfig(
-                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
+                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"])
                 )
             )
         return configs
@@ -413,6 +424,7 @@ class TogetherProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window_size,
+                    handle=self.get_handle(model_name)
                 )
             )
@@ -493,6 +505,7 @@ class GoogleAIProvider(Provider):
                     model_endpoint_type="google_ai",
                     model_endpoint=self.base_url,
                     context_window=self.get_model_context_window(model),
+                    handle=self.get_handle(model)
                 )
             )
         return configs
@@ -516,6 +529,7 @@ class GoogleAIProvider(Provider):
                     embedding_endpoint=self.base_url,
                     embedding_dim=768,
                     embedding_chunk_size=300,  # NOTE: max is 2048
+                    handle=self.get_handle(model)
                 )
             )
         return configs
@@ -556,7 +570,7 @@ class AzureProvider(Provider):
             context_window_size = self.get_model_context_window(model_name)
             model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
             configs.append(
-                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
+                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size, handle=self.get_handle(model_name))
             )
         return configs
@@ -577,6 +591,7 @@ class AzureProvider(Provider):
                     embedding_endpoint=model_endpoint,
                     embedding_dim=768,
                     embedding_chunk_size=300,  # NOTE: max is 2048
+                    handle=self.get_handle(model_name)
                 )
             )
         return configs
@@ -610,6 +625,7 @@ class VLLMChatCompletionsProvider(Provider):
                     model_endpoint_type="openai",
                     model_endpoint=self.base_url,
                     context_window=model["max_model_len"],
+                    handle=self.get_handle(model["id"])
                 )
             )
         return configs
@@ -642,6 +658,7 @@ class VLLMCompletionsProvider(Provider):
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=model["max_model_len"],
+                    handle=self.get_handle(model["id"])
                 )
             )
         return configs
@@ -43,6 +43,7 @@ class EmbeddingConfig(BaseModel):
     embedding_model: str = Field(..., description="The model for the embedding.")
     embedding_dim: int = Field(..., description="The dimension of the embedding.")
     embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")
+    handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")
 
     # azure only
     azure_endpoint: Optional[str] = Field(None, description="The Azure endpoint for the model.")
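With the optional field in place, an EmbeddingConfig carries its handle alongside the connection details. A reduced sketch covering only the fields this hunk touches (the real class has additional required fields):

```python
from typing import Optional

from pydantic import BaseModel, Field


# Reduced sketch of EmbeddingConfig: only the fields shown in this hunk.
class EmbeddingConfig(BaseModel):
    embedding_model: str = Field(..., description="The model for the embedding.")
    embedding_dim: int = Field(..., description="The dimension of the embedding.")
    embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")
    handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")


config = EmbeddingConfig(
    embedding_model="letta-free",
    embedding_dim=1024,
    handle="letta/letta-free",
)
assert config.handle == "letta/letta-free"
```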
@@ -44,6 +44,7 @@ class LLMConfig(BaseModel):
         True,
         description="Puts 'inner_thoughts' as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.",
     )
+    handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")
 
     # FIXME hack to silence pydantic protected namespace warning
     model_config = ConfigDict(protected_namespaces=())
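Because the handle is stored as provider/model-name, callers can split it back into its parts. The helper below is hypothetical and not part of this commit; it splits on the first slash only, since model names (e.g. on Together) can themselves contain slashes:

```python
# Hypothetical helper, NOT part of this commit: recover the provider and
# model from a stored handle. Split on the first "/" only, because model
# names may contain slashes themselves.
def split_handle(handle: str) -> tuple[str, str]:
    provider, _, model = handle.partition("/")
    return provider, model


assert split_handle("letta/letta-free") == ("letta", "letta-free")
assert split_handle("together/meta-llama/Llama-3-8b") == (
    "together",
    "meta-llama/Llama-3-8b",
)
```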