feat: store handle in configs (#2299)

Co-authored-by: Caren Thomas <caren@caren-mac.local>
This commit is contained in:
cthomas
2024-12-20 17:37:42 -08:00
committed by GitHub
parent 21aa010319
commit 92829e7daf
3 changed files with 22 additions and 3 deletions

View File

@@ -27,6 +27,10 @@ class Provider(BaseModel):
def provider_tag(self) -> str:
"""String representation of the provider for display purposes"""
raise NotImplementedError
def get_handle(self, model_name: str) -> str:
return f"{self.name}/{model_name}"
class LettaProvider(Provider):
@@ -40,6 +44,7 @@ class LettaProvider(Provider):
model_endpoint_type="openai",
model_endpoint="https://inference.memgpt.ai",
context_window=16384,
handle=self.get_handle("letta-free")
)
]
@@ -51,6 +56,7 @@ class LettaProvider(Provider):
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_dim=1024,
embedding_chunk_size=300,
handle=self.get_handle("letta-free")
)
]
@@ -115,7 +121,7 @@ class OpenAIProvider(Provider):
# continue
configs.append(
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size)
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name))
)
# for OpenAI, sort in reverse order
@@ -135,6 +141,7 @@ class OpenAIProvider(Provider):
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-ada-002")
)
]
@@ -163,6 +170,7 @@ class AnthropicProvider(Provider):
model_endpoint_type="anthropic",
model_endpoint=self.base_url,
context_window=model["context_window"],
handle=self.get_handle(model["name"])
)
)
return configs
@@ -195,6 +203,7 @@ class MistralProvider(Provider):
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_context_length"],
handle=self.get_handle(model["id"])
)
)
@@ -250,6 +259,7 @@ class OllamaProvider(OpenAIProvider):
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
handle=self.get_handle(model["name"])
)
)
return configs
@@ -325,6 +335,7 @@ class OllamaProvider(OpenAIProvider):
embedding_endpoint=self.base_url,
embedding_dim=embedding_dim,
embedding_chunk_size=300,
handle=self.get_handle(model["name"])
)
)
return configs
@@ -345,7 +356,7 @@ class GroqProvider(OpenAIProvider):
continue
configs.append(
LLMConfig(
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"])
)
)
return configs
@@ -413,6 +424,7 @@ class TogetherProvider(OpenAIProvider):
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window_size,
handle=self.get_handle(model_name)
)
)
@@ -493,6 +505,7 @@ class GoogleAIProvider(Provider):
model_endpoint_type="google_ai",
model_endpoint=self.base_url,
context_window=self.get_model_context_window(model),
handle=self.get_handle(model)
)
)
return configs
@@ -516,6 +529,7 @@ class GoogleAIProvider(Provider):
embedding_endpoint=self.base_url,
embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048
handle=self.get_handle(model)
)
)
return configs
@@ -556,7 +570,7 @@ class AzureProvider(Provider):
context_window_size = self.get_model_context_window(model_name)
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
configs.append(
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size), handle=self.get_handle(model_name)
)
return configs
@@ -577,6 +591,7 @@ class AzureProvider(Provider):
embedding_endpoint=model_endpoint,
embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048
handle=self.get_handle(model_name)
)
)
return configs
@@ -610,6 +625,7 @@ class VLLMChatCompletionsProvider(Provider):
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"])
)
)
return configs
@@ -642,6 +658,7 @@ class VLLMCompletionsProvider(Provider):
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"])
)
)
return configs

View File

@@ -43,6 +43,7 @@ class EmbeddingConfig(BaseModel):
embedding_model: str = Field(..., description="The model for the embedding.")
embedding_dim: int = Field(..., description="The dimension of the embedding.")
embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")
handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")
# azure only
azure_endpoint: Optional[str] = Field(None, description="The Azure endpoint for the model.")

View File

@@ -44,6 +44,7 @@ class LLMConfig(BaseModel):
True,
description="Puts 'inner_thoughts' as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.",
)
handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")
# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())