From 878511b4600d755821064e69649fbc69fb6a3520 Mon Sep 17 00:00:00 2001
From: cthomas
Date: Fri, 23 May 2025 09:09:47 -0700
Subject: [PATCH] feat(asyncify): migrate list models (#2369)

---
 letta/schemas/providers.py |   2 +-
 letta/server/server.py     | 119 +++++++++++++++++++++++++++++++++++--
 2 files changed, 116 insertions(+), 5 deletions(-)

diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py
index 9f17737a..86f919eb 100644
--- a/letta/schemas/providers.py
+++ b/letta/schemas/providers.py
@@ -54,7 +54,7 @@ class Provider(ProviderBase):
         return []
 
     async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
-        return []
+        return self.list_embedding_models()
 
     def get_model_context_window(self, model_name: str) -> Optional[int]:
         raise NotImplementedError
diff --git a/letta/server/server.py b/letta/server/server.py
index 45cdc882..80001a02 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -741,6 +741,13 @@ class SyncServer(Server):
             self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
         return self._llm_config_cache[key]
 
+    @trace_method
+    async def get_cached_llm_config_async(self, actor: User, **kwargs):
+        key = make_key(**kwargs)
+        if key not in self._llm_config_cache:
+            self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
+        return self._llm_config_cache[key]
+
     @trace_method
     def get_cached_embedding_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
@@ -748,6 +755,13 @@ class SyncServer(Server):
             self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
         return self._embedding_config_cache[key]
 
+    @trace_method
+    async def get_cached_embedding_config_async(self, actor: User, **kwargs):
+        key = make_key(**kwargs)
+        if key not in self._embedding_config_cache:
+            self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
+        return self._embedding_config_cache[key]
+
     @trace_method
     def create_agent(
         self,
@@ -815,7 +829,7 @@ class SyncServer(Server):
                 "enable_reasoner": request.enable_reasoner,
             }
             log_event(name="start get_cached_llm_config", attributes=config_params)
-            request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
+            request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
             log_event(name="end get_cached_llm_config", attributes=config_params)
 
         if request.embedding_config is None:
@@ -826,7 +840,7 @@ class SyncServer(Server):
                 "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
             }
             log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
-            request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
+            request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
             log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
 
         log_event(name="start create_agent db")
@@ -877,10 +891,10 @@ class SyncServer(Server):
         actor: User,
     ) -> AgentState:
         if request.model is not None:
-            request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
+            request.llm_config = await self.get_llm_config_from_handle_async(handle=request.model, actor=actor)
 
         if request.embedding is not None:
-            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
+            request.embedding_config = await self.get_embedding_config_from_handle_async(handle=request.embedding, actor=actor)
 
         if request.enable_sleeptime:
             agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -1568,6 +1582,61 @@ class SyncServer(Server):
 
         return llm_config
 
+    @trace_method
+    async def get_llm_config_from_handle_async(
+        self,
+        actor: User,
+        handle: str,
+        context_window_limit: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        max_reasoning_tokens: Optional[int] = None,
+        enable_reasoner: Optional[bool] = None,
+    ) -> LLMConfig:
+        try:
+            provider_name, model_name = handle.split("/", 1)
+            provider = await self.get_provider_from_name_async(provider_name, actor)
+
+            all_llm_configs = await provider.list_llm_models_async()
+            llm_configs = [config for config in all_llm_configs if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in all_llm_configs if config.model == model_name]
+            if not llm_configs:
+                available_handles = [config.handle for config in all_llm_configs]
+                raise HandleNotFoundError(handle, available_handles)
+        except ValueError as e:
+            llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
+            if not llm_configs:
+                raise e
+
+        if len(llm_configs) == 1:
+            llm_config = llm_configs[0]
+        elif len(llm_configs) > 1:
+            raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
+        else:
+            llm_config = llm_configs[0]
+
+        if context_window_limit is not None:
+            if context_window_limit > llm_config.context_window:
+                raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})")
+            llm_config.context_window = context_window_limit
+        else:
+            llm_config.context_window = min(llm_config.context_window, model_settings.global_max_context_window_limit)
+
+        if max_tokens is not None:
+            llm_config.max_tokens = max_tokens
+        if max_reasoning_tokens is not None:
+            if not max_tokens or max_reasoning_tokens > max_tokens:
+                raise ValueError(f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})")
+            llm_config.max_reasoning_tokens = max_reasoning_tokens
+        if enable_reasoner is not None:
+            llm_config.enable_reasoner = enable_reasoner
+            if enable_reasoner and llm_config.model_endpoint_type == "anthropic":
+                llm_config.put_inner_thoughts_in_kwargs = False
+
+        return llm_config
+
     @trace_method
     def get_embedding_config_from_handle(
         self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
     ) -> EmbeddingConfig:
@@ -1597,6 +1666,36 @@ class SyncServer(Server):
 
         return embedding_config
 
+    @trace_method
+    async def get_embedding_config_from_handle_async(
+        self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+    ) -> EmbeddingConfig:
+        try:
+            provider_name, model_name = handle.split("/", 1)
+            provider = await self.get_provider_from_name_async(provider_name, actor)
+
+            all_embedding_configs = await provider.list_embedding_models_async()
+            embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
+            if not embedding_configs:
+                raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
+        except ValueError as e:
+            # search local configs
+            embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
+            if not embedding_configs:
+                raise e
+
+        if len(embedding_configs) == 1:
+            embedding_config = embedding_configs[0]
+        elif len(embedding_configs) > 1:
+            raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
+        else:
+            embedding_config = embedding_configs[0]
+
+        if embedding_chunk_size:
+            embedding_config.embedding_chunk_size = embedding_chunk_size
+
+        return embedding_config
+
     def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
         providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
         if not providers:
@@ -1608,6 +1707,18 @@ class SyncServer(Server):
 
         return provider
 
+    async def get_provider_from_name_async(self, provider_name: str, actor: User) -> Provider:
+        all_providers = await self.get_enabled_providers_async(actor)
+        providers = [provider for provider in all_providers if provider.name == provider_name]
+        if not providers:
+            raise ValueError(f"Provider {provider_name} is not supported")
+        elif len(providers) > 1:
+            raise ValueError(f"Multiple providers with name {provider_name} supported")
+        else:
+            provider = providers[0]
+
+        return provider
+
     def get_local_llm_configs(self):
         llm_models = []
         try:
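
Usage note (illustrative, not part of the patch): a minimal sketch of how the new async resolution path added above might be driven once a `SyncServer` and a `User` actor exist. The setup shown in the trailing comments and both model handles are assumptions for illustration, not Letta defaults.

    import asyncio

    from letta.server.server import SyncServer

    async def resolve_configs(server: SyncServer, actor) -> None:
        # Handles take the "<provider-name>/<model-name>" form parsed by
        # handle.split("/", 1) above; both handles here are illustrative.
        llm_config = await server.get_llm_config_from_handle_async(
            actor=actor,
            handle="openai/gpt-4o-mini",
            context_window_limit=32_000,  # optional; validated against the model's maximum
        )
        embedding_config = await server.get_embedding_config_from_handle_async(
            actor=actor,
            handle="openai/text-embedding-3-small",
        )
        print(llm_config.handle, embedding_config.handle)

    # In practice the server and actor come from Letta's own startup path, e.g.:
    # server = SyncServer()
    # actor = server.user_manager.get_user_or_default(user_id=None)  # assumed helper
    # asyncio.run(resolve_configs(server, actor))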