From 92829e7daf647a220c89c92db17d40faac300f43 Mon Sep 17 00:00:00 2001
From: cthomas
Date: Fri, 20 Dec 2024 17:37:42 -0800
Subject: [PATCH] feat: store handle in configs (#2299)

Co-authored-by: Caren Thomas
---
 letta/providers.py                | 23 ++++++++++++++++++++---
 letta/schemas/embedding_config.py |  1 +
 letta/schemas/llm_config.py       |  1 +
 3 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/letta/providers.py b/letta/providers.py
index 83069a76..e8ebadfa 100644
--- a/letta/providers.py
+++ b/letta/providers.py
@@ -27,6 +27,10 @@ class Provider(BaseModel):
 
     def provider_tag(self) -> str:
         """String representation of the provider for display purposes"""
         raise NotImplementedError
+
+    def get_handle(self, model_name: str) -> str:
+        return f"{self.name}/{model_name}"
+
 
 class LettaProvider(Provider):
@@ -40,6 +44,7 @@ class LettaProvider(Provider):
                 model_endpoint_type="openai",
                 model_endpoint="https://inference.memgpt.ai",
                 context_window=16384,
+                handle=self.get_handle("letta-free")
             )
         ]
 
@@ -51,6 +56,7 @@
                 embedding_endpoint="https://embeddings.memgpt.ai",
                 embedding_dim=1024,
                 embedding_chunk_size=300,
+                handle=self.get_handle("letta-free")
             )
         ]
 
@@ -115,7 +121,7 @@ class OpenAIProvider(Provider):
                 # continue
 
             configs.append(
-                LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size)
+                LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name))
             )
 
         # for OpenAI, sort in reverse order
@@ -135,6 +141,7 @@
                 embedding_endpoint="https://api.openai.com/v1",
                 embedding_dim=1536,
                 embedding_chunk_size=300,
+                handle=self.get_handle("text-embedding-ada-002")
             )
         ]
 
@@ -163,6 +170,7 @@ class AnthropicProvider(Provider):
                     model_endpoint_type="anthropic",
                     model_endpoint=self.base_url,
                     context_window=model["context_window"],
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -195,6 +203,7 @@ class MistralProvider(Provider):
                         model_endpoint_type="openai",
                         model_endpoint=self.base_url,
                         context_window=model["max_context_length"],
+                        handle=self.get_handle(model["id"])
                     )
                 )
 
@@ -250,6 +259,7 @@ class OllamaProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -325,6 +335,7 @@ class OllamaProvider(OpenAIProvider):
                     embedding_endpoint=self.base_url,
                     embedding_dim=embedding_dim,
                     embedding_chunk_size=300,
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -345,7 +356,7 @@ class GroqProvider(OpenAIProvider):
                 continue
             configs.append(
                 LLMConfig(
-                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
+                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"])
                 )
             )
         return configs
@@ -413,6 +424,7 @@ class TogetherProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window_size,
+                    handle=self.get_handle(model_name)
                 )
             )
 
@@ -493,6 +505,7 @@ class GoogleAIProvider(Provider):
                     model_endpoint_type="google_ai",
                     model_endpoint=self.base_url,
                     context_window=self.get_model_context_window(model),
+                    handle=self.get_handle(model)
                 )
             )
         return configs
@@ -516,6 +529,7 @@ class GoogleAIProvider(Provider):
                     embedding_endpoint=self.base_url,
                     embedding_dim=768,
                     embedding_chunk_size=300,  # NOTE: max is 2048
+                    handle=self.get_handle(model)
                 )
             )
         return configs
@@ -556,7 +570,7 @@ class AzureProvider(Provider):
             context_window_size = self.get_model_context_window(model_name)
             model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
             configs.append(
-                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
+                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size, handle=self.get_handle(model_name))
             )
         return configs
 
@@ -577,6 +591,7 @@ class AzureProvider(Provider):
                     embedding_endpoint=model_endpoint,
                     embedding_dim=768,
                     embedding_chunk_size=300,  # NOTE: max is 2048
+                    handle=self.get_handle(model_name)
                 )
             )
         return configs
@@ -610,6 +625,7 @@ class VLLMChatCompletionsProvider(Provider):
                     model_endpoint_type="openai",
                     model_endpoint=self.base_url,
                     context_window=model["max_model_len"],
+                    handle=self.get_handle(model["id"])
                 )
             )
         return configs
@@ -642,6 +658,7 @@ class VLLMCompletionsProvider(Provider):
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=model["max_model_len"],
+                    handle=self.get_handle(model["id"])
                 )
             )
         return configs
diff --git a/letta/schemas/embedding_config.py b/letta/schemas/embedding_config.py
index dcd80c0f..7a8236c3 100644
--- a/letta/schemas/embedding_config.py
+++ b/letta/schemas/embedding_config.py
@@ -43,6 +43,7 @@ class EmbeddingConfig(BaseModel):
     embedding_model: str = Field(..., description="The model for the embedding.")
     embedding_dim: int = Field(..., description="The dimension of the embedding.")
     embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")
+    handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")
 
     # azure only
     azure_endpoint: Optional[str] = Field(None, description="The Azure endpoint for the model.")
diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py
index ed63e766..0be4f818 100644
--- a/letta/schemas/llm_config.py
+++ b/letta/schemas/llm_config.py
@@ -44,6 +44,7 @@ class LLMConfig(BaseModel):
         True,
         description="Puts 'inner_thoughts' as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.",
     )
+    handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")
 
     # FIXME hack to silence pydantic protected namespace warning
     model_config = ConfigDict(protected_namespaces=())
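For reference, the sketch below (not part of the patch) illustrates the handle pattern the patch introduces: every provider-emitted config now carries a stable "provider/model-name" identifier built by Provider.get_handle. It is a minimal, self-contained sketch assuming only pydantic; ExampleProvider and its model names are hypothetical stand-ins, while the real classes live in letta/providers.py, letta/schemas/llm_config.py, and letta/schemas/embedding_config.py.

    # Minimal sketch of the handle pattern introduced by this patch.
    # NOTE: `ExampleProvider` and "example-model" are hypothetical stand-ins,
    # not part of the letta codebase.
    from typing import List, Optional

    from pydantic import BaseModel, Field


    class LLMConfig(BaseModel):
        model: str
        context_window: int
        # The new field: a stable "provider/model-name" identifier for the config.
        handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")


    class Provider(BaseModel):
        name: str

        def get_handle(self, model_name: str) -> str:
            # Namespacing by provider keeps e.g. "openai/gpt-4" distinct from
            # an Azure deployment of the same model, "azure/gpt-4".
            return f"{self.name}/{model_name}"


    class ExampleProvider(Provider):
        name: str = "example"

        def list_llm_models(self) -> List[LLMConfig]:
            # Each config the provider emits now carries its handle, so callers
            # can refer to a model by one string instead of re-deriving
            # provider + model name themselves.
            return [LLMConfig(model="example-model", context_window=8192, handle=self.get_handle("example-model"))]


    if __name__ == "__main__":
        for config in ExampleProvider().list_llm_models():
            print(config.handle)  # prints: example/example-model

Storing the handle on the config itself (rather than recomputing it at lookup time) means the identifier survives serialization, which is presumably why both LLMConfig and EmbeddingConfig gain the field.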