fix: undo recent changes to handle matching (#5808)
committed by Caren Thomas
parent 655c9489d8
commit f37197f00d
@@ -1034,67 +1034,92 @@ class SyncServer(object):
         max_reasoning_tokens: Optional[int] = None,
         enable_reasoner: Optional[bool] = None,
     ) -> LLMConfig:
         """String match the `handle` to the available configs"""
-        matched_llm_config = None
-        available_handles = []
-        # Get all enabled providers (including BYOK providers from database)
-        providers = await self.get_enabled_providers_async(actor=actor)
-        for provider in providers:
-            llm_configs = await self.list_llm_models_async(actor=actor)
-            for llm_config in llm_configs:
-                available_handles.append(llm_config.handle)
-                if llm_config.handle == handle:
-                    matched_llm_config = llm_config
-                    break
-        if not matched_llm_config:
-            raise HandleNotFoundError(handle, available_handles)
+        try:
+            provider_name, model_name = handle.split("/", 1)
+            provider = await self.get_provider_from_name_async(provider_name, actor)
+
+            all_llm_configs = await provider.list_llm_models_async()
+            llm_configs = [config for config in all_llm_configs if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in all_llm_configs if config.model == model_name]
+            if not llm_configs:
+                available_handles = [config.handle for config in all_llm_configs]
+                raise HandleNotFoundError(handle, available_handles)
+        except ValueError as e:
+            llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
+            if not llm_configs:
+                raise e
+
+        if len(llm_configs) == 1:
+            llm_config = llm_configs[0]
+        elif len(llm_configs) > 1:
+            raise LettaInvalidArgumentError(
+                f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name"
+            )
+        else:
+            llm_config = llm_configs[0]

         if context_window_limit is not None:
-            if context_window_limit > matched_llm_config.context_window:
+            if context_window_limit > llm_config.context_window:
                 raise LettaInvalidArgumentError(
-                    f"Context window limit ({context_window_limit}) is greater than maximum of ({matched_llm_config.context_window})",
+                    f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})",
                     argument_name="context_window_limit",
                 )
-            matched_llm_config.context_window = context_window_limit
+            llm_config.context_window = context_window_limit
         else:
-            matched_llm_config.context_window = min(matched_llm_config.context_window, model_settings.global_max_context_window_limit)
+            llm_config.context_window = min(llm_config.context_window, model_settings.global_max_context_window_limit)

         if max_tokens is not None:
-            matched_llm_config.max_tokens = max_tokens
+            llm_config.max_tokens = max_tokens
         if max_reasoning_tokens is not None:
             if not max_tokens or max_reasoning_tokens > max_tokens:
                 raise LettaInvalidArgumentError(
                     f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})",
                     argument_name="max_reasoning_tokens",
                 )
-            matched_llm_config.max_reasoning_tokens = max_reasoning_tokens
+            llm_config.max_reasoning_tokens = max_reasoning_tokens
         if enable_reasoner is not None:
-            matched_llm_config.enable_reasoner = enable_reasoner
-            if enable_reasoner and matched_llm_config.model_endpoint_type == "anthropic":
-                matched_llm_config.put_inner_thoughts_in_kwargs = False
+            llm_config.enable_reasoner = enable_reasoner
+            if enable_reasoner and llm_config.model_endpoint_type == "anthropic":
+                llm_config.put_inner_thoughts_in_kwargs = False

-        return matched_llm_config
+        return llm_config

     @trace_method
     async def get_embedding_config_from_handle_async(
         self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
     ) -> EmbeddingConfig:
-        matched_embedding_config = None
-        available_handles = []
-        # Get all enabled providers (including BYOK providers from database)
-        providers = await self.get_enabled_providers_async(actor=actor)
-        for provider in providers:
-            embedding_configs = await self.list_embedding_models_async(actor=actor)
-            for embedding_config in embedding_configs:
-                available_handles.append(embedding_config.handle)
-                if embedding_config.handle == handle:
-                    matched_embedding_config = embedding_config
-                    break
+        try:
+            provider_name, model_name = handle.split("/", 1)
+            provider = await self.get_provider_from_name_async(provider_name, actor)
+
+            all_embedding_configs = await provider.list_embedding_models_async()
+            embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
+            if not embedding_configs:
+                raise LettaInvalidArgumentError(
+                    f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name"
+                )
+        except LettaInvalidArgumentError as e:
+            # search local configs
+            embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
+            if not embedding_configs:
+                raise e
+
+        if len(embedding_configs) == 1:
+            embedding_config = embedding_configs[0]
+        elif len(embedding_configs) > 1:
+            raise LettaInvalidArgumentError(
+                f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name"
+            )
+        else:
+            embedding_config = embedding_configs[0]

         if embedding_chunk_size:
-            matched_embedding_config.embedding_chunk_size = embedding_chunk_size
+            embedding_config.embedding_chunk_size = embedding_chunk_size

-        return matched_embedding_config
+        return embedding_config

     async def get_provider_from_name_async(self, provider_name: str, actor: User) -> Provider:
         all_providers = await self.get_enabled_providers_async(actor)
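For reference, the handle-matching strategy this commit restores can be condensed into a standalone sketch. Everything below is illustrative, not code from the Letta repository: `Config`, `resolve_handle`, and the sample data are hypothetical stand-ins for the server's provider lookup, and plain `KeyError`/`ValueError` stand in for `HandleNotFoundError` and `LettaInvalidArgumentError`.

from dataclasses import dataclass
from typing import List

@dataclass
class Config:
    handle: str  # full handle, e.g. "openai/gpt-4o-mini"
    model: str   # bare model name, e.g. "gpt-4o-mini"

def resolve_handle(handle: str, configs: List[Config]) -> Config:
    # Split "provider/model"; a handle with no "/" raises ValueError,
    # which the server-side code catches to fall back to local configs.
    provider_name, model_name = handle.split("/", 1)

    # Exact handle match first...
    matches = [c for c in configs if c.handle == handle]
    # ...then fall back to matching the bare model name.
    if not matches:
        matches = [c for c in configs if c.model == model_name]
    if not matches:
        raise KeyError(f"{handle!r} not found; available: {[c.handle for c in configs]}")
    if len(matches) > 1:
        raise ValueError(f"Multiple models named {model_name!r} under {provider_name!r}")
    return matches[0]

# Usage: the exact handle misses here, but the model-name fallback matches.
configs = [Config("openai/gpt-4o-mini-latest", "gpt-4o-mini")]
assert resolve_handle("openai/gpt-4o-mini", configs) is configs[0]

The reverted code had replaced this two-stage fallback with a single pass over all enabled providers' handles, which is why the commit restores both the model-name fallback and the local-config escape hatch.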
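The per-request overrides applied after matching are unchanged in substance; only the receiving variable went from `matched_llm_config` back to `llm_config`. A minimal sketch of that validation logic, assuming any object with `context_window`, `max_tokens`, and `max_reasoning_tokens` attributes, and a hypothetical `GLOBAL_MAX` standing in for `model_settings.global_max_context_window_limit`:

from typing import Optional

GLOBAL_MAX = 128_000  # hypothetical stand-in for model_settings.global_max_context_window_limit

def apply_overrides(
    config,
    context_window_limit: Optional[int] = None,
    max_tokens: Optional[int] = None,
    max_reasoning_tokens: Optional[int] = None,
):
    if context_window_limit is not None:
        # An explicit limit may only shrink the model's window, never grow it.
        if context_window_limit > config.context_window:
            raise ValueError(f"Context window limit ({context_window_limit}) exceeds maximum ({config.context_window})")
        config.context_window = context_window_limit
    else:
        # No explicit limit: clamp to the global ceiling.
        config.context_window = min(config.context_window, GLOBAL_MAX)

    if max_tokens is not None:
        config.max_tokens = max_tokens
    if max_reasoning_tokens is not None:
        # The reasoning budget must fit inside the overall token budget.
        if not max_tokens or max_reasoning_tokens > max_tokens:
            raise ValueError(f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})")
        config.max_reasoning_tokens = max_reasoning_tokens
    return config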