diff --git a/letta/groups/dynamic_multi_agent.py b/letta/groups/dynamic_multi_agent.py index f89a6f3a..500d923d 100644 --- a/letta/groups/dynamic_multi_agent.py +++ b/letta/groups/dynamic_multi_agent.py @@ -94,6 +94,7 @@ class DynamicMultiAgent(Agent): for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]: if name.lower() in assistant_message.content.lower(): speaker_id = agent_id + assert speaker_id is not None, f"No names found in {assistant_message.content}" # Sum usage total_usage.prompt_tokens += usage_stats.prompt_tokens diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 2fb9b865..8b8445a2 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -717,7 +717,7 @@ def _prepare_anthropic_request( data["temperature"] = 1.0 if "functions" in data: - raise ValueError(f"'functions' unexpected in Anthropic API payload") + raise ValueError("'functions' unexpected in Anthropic API payload") # Handle tools if "tools" in data and data["tools"] is None: @@ -1150,7 +1150,7 @@ def anthropic_chat_completions_process_stream( accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments if message_delta.function_call is not None: - raise NotImplementedError(f"Old function_call style not support with stream=True") + raise NotImplementedError("Old function_call style not support with stream=True") # overwrite response fields based on latest chunk if not create_message_id: diff --git a/letta/llm_api/aws_bedrock.py b/letta/llm_api/aws_bedrock.py index 539ce0fd..fd994795 100644 --- a/letta/llm_api/aws_bedrock.py +++ b/letta/llm_api/aws_bedrock.py @@ -1,17 +1,30 @@ +""" +Note that this formally only supports Anthropic Bedrock. +TODO (cliandy): determine what other providers are supported and what is needed to add support. 
+""" + import os -from typing import Any, Dict, List, Optional +from typing import Any, Optional from anthropic import AnthropicBedrock +from letta.log import get_logger from letta.settings import model_settings +logger = get_logger(__name__) + def has_valid_aws_credentials() -> bool: """ Check if AWS credentials are properly configured. """ - valid_aws_credentials = os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY") and os.getenv("AWS_DEFAULT_REGION") - return valid_aws_credentials + return all( + ( + os.getenv("AWS_ACCESS_KEY_ID"), + os.getenv("AWS_SECRET_ACCESS_KEY"), + os.getenv("AWS_DEFAULT_REGION"), + ) + ) def get_bedrock_client( @@ -41,48 +54,11 @@ def get_bedrock_client( return bedrock -def bedrock_get_model_list( - region_name: str, - access_key_id: Optional[str] = None, - secret_access_key: Optional[str] = None, -) -> List[dict]: - """ - Get list of available models from Bedrock. - - Args: - region_name: AWS region name - access_key_id: Optional AWS access key ID - secret_access_key: Optional AWS secret access key - - TODO: Implement model_provider and output_modality filtering - model_provider: Optional provider name to filter models. If None, returns all models. - output_modality: Output modality to filter models. Defaults to "text". 
- - Returns: - List of model summaries - - """ - import boto3 - - try: - bedrock = boto3.client( - "bedrock", - region_name=region_name, - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - ) - response = bedrock.list_inference_profiles() - return response["inferenceProfileSummaries"] - except Exception as e: - print(f"Error getting model list: {str(e)}") - raise e - - async def bedrock_get_model_list_async( access_key_id: Optional[str] = None, secret_access_key: Optional[str] = None, default_region: Optional[str] = None, -) -> List[dict]: +) -> list[dict]: from aioboto3.session import Session try: @@ -96,11 +72,11 @@ async def bedrock_get_model_list_async( response = await bedrock.list_inference_profiles() return response["inferenceProfileSummaries"] except Exception as e: - print(f"Error getting model list: {str(e)}") + logger.error(f"Error getting model list for bedrock: %s", e) raise e -def bedrock_get_model_details(region_name: str, model_id: str) -> Dict[str, Any]: +def bedrock_get_model_details(region_name: str, model_id: str) -> dict[str, Any]: """ Get details for a specific model from Bedrock. """ @@ -121,54 +97,8 @@ def bedrock_get_model_context_window(model_id: str) -> int: Get context window size for a specific model. """ # Bedrock doesn't provide this via API, so we maintain a mapping - context_windows = { - "anthropic.claude-3-5-sonnet-20241022-v2:0": 200000, - "anthropic.claude-3-5-sonnet-20240620-v1:0": 200000, - "anthropic.claude-3-5-haiku-20241022-v1:0": 200000, - "anthropic.claude-3-haiku-20240307-v1:0": 200000, - "anthropic.claude-3-opus-20240229-v1:0": 200000, - "anthropic.claude-3-sonnet-20240229-v1:0": 200000, - } - return context_windows.get(model_id, 200000) # default to 100k if unknown - - -""" -{ - "id": "msg_123", - "type": "message", - "role": "assistant", - "model": "anthropic.claude-3-5-sonnet-20241022-v2:0", - "content": [ - { - "type": "text", - "text": "I see the Firefox icon. 
Let me click on it and then navigate to a weather website." - }, - { - "type": "tool_use", - "id": "toolu_123", - "name": "computer", - "input": { - "action": "mouse_move", - "coordinate": [ - 708, - 736 - ] - } - }, - { - "type": "tool_use", - "id": "toolu_234", - "name": "computer", - "input": { - "action": "left_click" - } - } - ], - "stop_reason": "tool_use", - "stop_sequence": null, - "usage": { - "input_tokens": 3391, - "output_tokens": 132 - } -} -""" + # 200k for anthropic: https://aws.amazon.com/bedrock/anthropic/ + if model_id.startswith("anthropic"): + return 200_000 + else: + return 100_000 # default to 100k if unknown diff --git a/letta/llm_api/deepseek.py b/letta/llm_api/deepseek.py index f0b2a45a..5d4eb9e1 100644 --- a/letta/llm_api/deepseek.py +++ b/letta/llm_api/deepseek.py @@ -120,7 +120,7 @@ def build_deepseek_chat_completions_request( def add_functions_to_system_message(system_message: ChatMessage): system_message.content += f" {''.join(json.dumps(f) for f in functions)} " - system_message.content += f'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.' + system_message.content += 'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.' 
if llm_config.model == "deepseek-reasoner": # R1 currently doesn't support function calling natively add_functions_to_system_message( diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index a8fa03f4..e7987aa2 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -66,44 +66,6 @@ def google_ai_check_valid_api_key(api_key: str): raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR) -def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]: - """Synchronous version to get model list from Google AI API using httpx.""" - import httpx - - from letta.utils import printd - - url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header) - - try: - with httpx.Client() as client: - response = client.get(url, headers=headers) - response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status - response_data = response.json() # convert to dict from string - - # Grab the models out - model_list = response_data["models"] - return model_list - - except httpx.HTTPStatusError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, exception={http_err}") - # Print the HTTP status code - print(f"HTTP Error: {http_err.response.status_code}") - # Print the response content (error message from server) - print(f"Message: {http_err.response.text}") - raise http_err - - except httpx.RequestError as req_err: - # Handle other httpx-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e - - async def google_ai_get_model_list_async( base_url: str, api_key: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None ) -> List[dict]: diff --git a/letta/llm_api/google_constants.py b/letta/llm_api/google_constants.py index 
1c30d615..3a50fb1b 100644 --- a/letta/llm_api/google_constants.py +++ b/letta/llm_api/google_constants.py @@ -1,7 +1,12 @@ GOOGLE_MODEL_TO_CONTEXT_LENGTH = { + "gemini-2.5-pro": 1048576, + "gemini-2.5-flash": 1048576, + "gemini-live-2.5-flash": 1048576, + "gemini-2.0-flash-001": 1048576, + "gemini-2.0-flash-lite-001": 1048576, + # The following are either deprecated or discontinued. "gemini-2.5-pro-exp-03-25": 1048576, "gemini-2.5-flash-preview-04-17": 1048576, - "gemini-2.0-flash-001": 1048576, "gemini-2.0-pro-exp-02-05": 2097152, "gemini-2.0-flash-lite-preview-02-05": 1048576, "gemini-2.0-flash-thinking-exp-01-21": 1048576, @@ -11,8 +16,6 @@ GOOGLE_MODEL_TO_CONTEXT_LENGTH = { "gemini-1.0-pro-vision": 16384, } -GOOGLE_MODEL_TO_OUTPUT_LENGTH = {"gemini-2.0-flash-001": 8192, "gemini-2.5-pro-exp-03-25": 65536} - GOOGLE_EMBEDING_MODEL_TO_DIM = {"text-embedding-005": 768, "text-multilingual-embedding-002": 768} GOOGLE_MODEL_FOR_API_KEY_CHECK = "gemini-2.0-flash-lite" diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index 749ae974..cff7b96a 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -252,7 +252,7 @@ def unpack_all_inner_thoughts_from_kwargs( ) -> ChatCompletionResponse: """Strip the inner thoughts out of the tool call and put it in the message content""" if len(response.choices) == 0: - raise ValueError(f"Unpacking inner thoughts from empty response not supported") + raise ValueError("Unpacking inner thoughts from empty response not supported") new_choices = [] for choice in response.choices: diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 0f46a17a..bf943ed4 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -67,7 +67,6 @@ def retry_with_exponential_backoff( # Stop retrying if user hits Ctrl-C raise KeyboardInterrupt("User intentionally stopped thread. 
Stopping...") except requests.exceptions.HTTPError as http_err: - if not hasattr(http_err, "response") or not http_err.response: raise @@ -175,7 +174,6 @@ def create( # openai if llm_config.model_endpoint_type == "openai": - if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1": # only is a problem if we are *not* using an openai proxy raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"]) @@ -256,7 +254,6 @@ def create( return response elif llm_config.model_endpoint_type == "xai": - api_key = model_settings.xai_api_key if function_call is None and functions is not None and len(functions) > 0: @@ -464,7 +461,7 @@ def create( # ) elif llm_config.model_endpoint_type == "groq": if stream: - raise NotImplementedError(f"Streaming not yet implemented for Groq.") + raise NotImplementedError("Streaming not yet implemented for Groq.") if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions": raise LettaConfigurationError(message="Groq key is missing from letta config file", missing_fields=["groq_api_key"]) @@ -517,7 +514,7 @@ def create( """TogetherAI endpoint that goes via /completions instead of /chat/completions""" if stream: - raise NotImplementedError(f"Streaming not yet implemented for TogetherAI (via the /completions endpoint).") + raise NotImplementedError("Streaming not yet implemented for TogetherAI (via the /completions endpoint).") if model_settings.together_api_key is None and ( llm_config.model_endpoint == "https://api.together.ai/v1/completions" @@ -547,7 +544,7 @@ def create( """Anthropic endpoint that goes via /embeddings instead of /chat/completions""" if stream: - raise NotImplementedError(f"Streaming not yet implemented for Anthropic (via the /embeddings endpoint).") + raise NotImplementedError("Streaming not yet implemented for Anthropic (via the /embeddings endpoint).") if not 
use_tool_naming: raise NotImplementedError("Only tool calling supported on Anthropic API requests") @@ -631,7 +628,7 @@ def create( messages[0].content[0].text += f" {''.join(json.dumps(f) for f in functions)} " messages[0].content[ 0 - ].text += f'Select best function to call simply by responding with a single json block with the keys "function" and "params". Use double quotes around the arguments.' + ].text += 'Select best function to call simply by responding with a single json block with the keys "function" and "params". Use double quotes around the arguments.' return get_chat_completion( model=llm_config.model, messages=messages, diff --git a/letta/llm_api/mistral.py b/letta/llm_api/mistral.py index 932cf874..68bb87e9 100644 --- a/letta/llm_api/mistral.py +++ b/letta/llm_api/mistral.py @@ -1,47 +1,22 @@ -import requests +import aiohttp -from letta.utils import printd, smart_urljoin +from letta.log import get_logger +from letta.utils import smart_urljoin + +logger = get_logger(__name__) -def mistral_get_model_list(url: str, api_key: str) -> dict: +async def mistral_get_model_list_async(url: str, api_key: str) -> dict: url = smart_urljoin(url, "models") headers = {"Content-Type": "application/json"} if api_key is not None: headers["Authorization"] = f"Bearer {api_key}" - printd(f"Sending request to {url}") - response = None - try: + logger.debug(f"Sending request to %s", url) + + async with aiohttp.ClientSession() as session: # TODO add query param "tool" to be true - response = requests.get(url, headers=headers) - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - response_json = response.json() # convert to dict from string - return response_json - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - try: - if response: - response = response.json() - except: - pass - printd(f"Got HTTPError, exception={http_err}, response={response}") - raise http_err - except requests.exceptions.RequestException 
as req_err: - # Handle other requests-related errors (e.g., connection error) - try: - if response: - response = response.json() - except: - pass - printd(f"Got RequestException, exception={req_err}, response={response}") - raise req_err - except Exception as e: - # Handle other potential errors - try: - if response: - response = response.json() - except: - pass - printd(f"Got unknown Exception, exception={e}, response={response}") - raise e + async with session.get(url, headers=headers) as response: + response.raise_for_status() + return await response.json() diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 7372b1d0..b83c4de4 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -59,11 +59,15 @@ def openai_check_valid_api_key(base_url: str, api_key: Union[str, None]) -> None def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool = False, extra_params: Optional[dict] = None) -> dict: """https://platform.openai.com/docs/api-reference/models/list""" - from letta.utils import printd # In some cases we may want to double-check the URL and do basic correction, eg: # In Letta config the address for vLLM is w/o a /v1 suffix for simplicity # However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit + + import warnings + + warnings.warn("The synchronous version of openai_get_model_list function is deprecated. 
Use the async one instead.", DeprecationWarning) + if fix_url: if not url.endswith("/v1"): url = smart_urljoin(url, "v1") @@ -74,14 +78,14 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool if api_key is not None: headers["Authorization"] = f"Bearer {api_key}" - printd(f"Sending request to {url}") + logger.debug(f"Sending request to {url}") response = None try: # TODO add query param "tool" to be true response = requests.get(url, headers=headers, params=extra_params) response.raise_for_status() # Raises HTTPError for 4XX/5XX status response = response.json() # convert to dict from string - printd(f"response = {response}") + logger.debug(f"response = {response}") return response except requests.exceptions.HTTPError as http_err: # Handle HTTP errors (e.g., response 4XX, 5XX) @@ -90,7 +94,7 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool response = response.json() except: pass - printd(f"Got HTTPError, exception={http_err}, response={response}") + logger.debug(f"Got HTTPError, exception={http_err}, response={response}") raise http_err except requests.exceptions.RequestException as req_err: # Handle other requests-related errors (e.g., connection error) @@ -99,7 +103,7 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool response = response.json() except: pass - printd(f"Got RequestException, exception={req_err}, response={response}") + logger.debug(f"Got RequestException, exception={req_err}, response={response}") raise req_err except Exception as e: # Handle other potential errors @@ -108,7 +112,7 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool response = response.json() except: pass - printd(f"Got unknown Exception, exception={e}, response={response}") + logger.debug(f"Got unknown Exception, exception={e}, response={response}") raise e @@ -120,7 +124,6 @@ async def openai_get_model_list_async( client: Optional["httpx.AsyncClient"] = 
None, ) -> dict: """https://platform.openai.com/docs/api-reference/models/list""" - from letta.utils import printd # In some cases we may want to double-check the URL and do basic correction if fix_url and not url.endswith("/v1"): @@ -132,7 +135,7 @@ async def openai_get_model_list_async( if api_key is not None: headers["Authorization"] = f"Bearer {api_key}" - printd(f"Sending request to {url}") + logger.debug(f"Sending request to {url}") # Use provided client or create a new one close_client = False @@ -144,24 +147,23 @@ async def openai_get_model_list_async( response = await client.get(url, headers=headers, params=extra_params) response.raise_for_status() result = response.json() - printd(f"response = {result}") + logger.debug(f"response = {result}") return result except httpx.HTTPStatusError as http_err: # Handle HTTP errors (e.g., response 4XX, 5XX) - error_response = None try: error_response = http_err.response.json() except: error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text} - printd(f"Got HTTPError, exception={http_err}, response={error_response}") + logger.debug(f"Got HTTPError, exception={http_err}, response={error_response}") raise http_err except httpx.RequestError as req_err: # Handle other httpx-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") + logger.debug(f"Got RequestException, exception={req_err}") raise req_err except Exception as e: # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") + logger.debug(f"Got unknown Exception, exception={e}") raise e finally: if close_client: @@ -480,7 +482,7 @@ def openai_chat_completions_process_stream( ) if message_delta.function_call is not None: - raise NotImplementedError(f"Old function_call style not support with stream=True") + raise NotImplementedError("Old function_call style not support with stream=True") # overwrite response fields based on latest chunk if not create_message_id: @@ 
-503,7 +505,7 @@ def openai_chat_completions_process_stream( logger.error(f"Parsing ChatCompletion stream failed with error:\n{str(e)}") raise e finally: - logger.info(f"Finally ending streaming interface.") + logger.info("Finally ending streaming interface.") if stream_interface: stream_interface.stream_end() @@ -525,7 +527,6 @@ def openai_chat_completions_process_stream( assert len(chat_completion_response.choices) > 0, f"No response from provider {chat_completion_response}" - # printd(chat_completion_response) log_event(name="llm_response_received", attributes=chat_completion_response.model_dump()) return chat_completion_response @@ -536,7 +537,6 @@ def openai_chat_completions_request_stream( chat_completion_request: ChatCompletionRequest, fix_url: bool = False, ) -> Generator[ChatCompletionChunkResponse, None, None]: - # In some cases we may want to double-check the URL and do basic correction, eg: # In Letta config the address for vLLM is w/o a /v1 suffix for simplicity # However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit diff --git a/letta/llm_api/sample_response_jsons/aws_bedrock.json b/letta/llm_api/sample_response_jsons/aws_bedrock.json new file mode 100644 index 00000000..c8ff79c8 --- /dev/null +++ b/letta/llm_api/sample_response_jsons/aws_bedrock.json @@ -0,0 +1,38 @@ +{ + "id": "msg_123", + "type": "message", + "role": "assistant", + "model": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "content": [ + { + "type": "text", + "text": "I see the Firefox icon. Let me click on it and then navigate to a weather website." 
+ }, + { + "type": "tool_use", + "id": "toolu_123", + "name": "computer", + "input": { + "action": "mouse_move", + "coordinate": [ + 708, + 736 + ] + } + }, + { + "type": "tool_use", + "id": "toolu_234", + "name": "computer", + "input": { + "action": "left_click" + } + } + ], + "stop_reason": "tool_use", + "stop_sequence": null, + "usage": { + "input_tokens": 3391, + "output_tokens": 132 + } +} diff --git a/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json b/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json new file mode 100644 index 00000000..25489ff3 --- /dev/null +++ b/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json @@ -0,0 +1,15 @@ +{ + "object": "list", + "data": [ + { + "id": "text-embedding-nomic-embed-text-v1.5", + "object": "model", + "type": "embeddings", + "publisher": "nomic-ai", + "arch": "nomic-bert", + "compatibility_type": "gguf", + "quantization": "Q4_0", + "state": "not-loaded", + "max_context_length": 2048 + }, + ... diff --git a/letta/llm_api/sample_response_jsons/lmstudio_model_list.json b/letta/llm_api/sample_response_jsons/lmstudio_model_list.json new file mode 100644 index 00000000..8b7e7b70 --- /dev/null +++ b/letta/llm_api/sample_response_jsons/lmstudio_model_list.json @@ -0,0 +1,15 @@ + { + "object": "list", + "data": [ + { + "id": "qwen2-vl-7b-instruct", + "object": "model", + "type": "vlm", + "publisher": "mlx-community", + "arch": "qwen2_vl", + "compatibility_type": "mlx", + "quantization": "4bit", + "state": "not-loaded", + "max_context_length": 32768 + }, + ..., diff --git a/letta/local_llm/constants.py b/letta/local_llm/constants.py index 83681fde..2b51101d 100644 --- a/letta/local_llm/constants.py +++ b/letta/local_llm/constants.py @@ -1,27 +1,5 @@ -# import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros from letta.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper -DEFAULT_ENDPOINTS = { - # Local - "koboldcpp": 
"http://localhost:5001", - "llamacpp": "http://localhost:8080", - "lmstudio": "http://localhost:1234", - "lmstudio-legacy": "http://localhost:1234", - "ollama": "http://localhost:11434", - "webui-legacy": "http://localhost:5000", - "webui": "http://localhost:5000", - "vllm": "http://localhost:8000", - # APIs - "openai": "https://api.openai.com", - "anthropic": "https://api.anthropic.com", - "groq": "https://api.groq.com/openai", -} - -DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K" - -# DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper -# DEFAULT_WRAPPER_NAME = "airoboros-l2-70b-2.1" - DEFAULT_WRAPPER = ChatMLInnerMonologueWrapper DEFAULT_WRAPPER_NAME = "chatml" diff --git a/letta/local_llm/llm_chat_completion_wrappers/airoboros.py b/letta/local_llm/llm_chat_completion_wrappers/airoboros.py index 5f076ec8..544d11d4 100644 --- a/letta/local_llm/llm_chat_completion_wrappers/airoboros.py +++ b/letta/local_llm/llm_chat_completion_wrappers/airoboros.py @@ -75,7 +75,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): func_str = "" func_str += f"{schema['name']}:" func_str += f"\n description: {schema['description']}" - func_str += f"\n params:" + func_str += "\n params:" for param_k, param_v in schema["parameters"]["properties"].items(): # TODO we're ignoring type func_str += f"\n {param_k}: {param_v['description']}" @@ -83,8 +83,8 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): return func_str # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." - prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." - prompt += f"\nAvailable functions:" + prompt += "\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. 
Provide your response in JSON format." + prompt += "\nAvailable functions:" if function_documentation is not None: prompt += f"\n{function_documentation}" else: @@ -150,7 +150,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): prompt += "\n### RESPONSE" if self.include_assistant_prefix: - prompt += f"\nASSISTANT:" + prompt += "\nASSISTANT:" if self.include_opening_brance_in_prefix: prompt += "\n{" @@ -282,9 +282,9 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): func_str = "" func_str += f"{schema['name']}:" func_str += f"\n description: {schema['description']}" - func_str += f"\n params:" + func_str += "\n params:" if add_inner_thoughts: - func_str += f"\n inner_thoughts: Deep inner monologue private to you only." + func_str += "\n inner_thoughts: Deep inner monologue private to you only." for param_k, param_v in schema["parameters"]["properties"].items(): # TODO we're ignoring type func_str += f"\n {param_k}: {param_v['description']}" @@ -292,8 +292,8 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): return func_str # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." - prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." - prompt += f"\nAvailable functions:" + prompt += "\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." 
+ prompt += "\nAvailable functions:" if function_documentation is not None: prompt += f"\n{function_documentation}" else: @@ -375,7 +375,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): prompt += "\n### RESPONSE" if self.include_assistant_prefix: - prompt += f"\nASSISTANT:" + prompt += "\nASSISTANT:" if self.assistant_prefix_extra: prompt += self.assistant_prefix_extra diff --git a/letta/local_llm/llm_chat_completion_wrappers/chatml.py b/letta/local_llm/llm_chat_completion_wrappers/chatml.py index 62e8a8cf..71589959 100644 --- a/letta/local_llm/llm_chat_completion_wrappers/chatml.py +++ b/letta/local_llm/llm_chat_completion_wrappers/chatml.py @@ -71,7 +71,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): func_str = "" func_str += f"{schema['name']}:" func_str += f"\n description: {schema['description']}" - func_str += f"\n params:" + func_str += "\n params:" if add_inner_thoughts: from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION @@ -87,8 +87,8 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): prompt = "" # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." - prompt += f"Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." - prompt += f"\nAvailable functions:" + prompt += "Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." 
+ prompt += "\nAvailable functions:" for function_dict in functions: prompt += f"\n{self._compile_function_description(function_dict)}" @@ -101,8 +101,8 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): prompt += system_message prompt += "\n" if function_documentation is not None: - prompt += f"Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." - prompt += f"\nAvailable functions:\n" + prompt += "Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." + prompt += "\nAvailable functions:\n" prompt += function_documentation else: prompt += self._compile_function_block(functions) @@ -230,7 +230,6 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): prompt += f"\n<|im_start|>{role_str}\n{msg_str.strip()}<|im_end|>" elif message["role"] == "system": - role_str = "system" msg_str = self._compile_system_message( system_message=message["content"], functions=functions, function_documentation=function_documentation @@ -255,7 +254,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): raise ValueError(message) if self.include_assistant_prefix: - prompt += f"\n<|im_start|>assistant" + prompt += "\n<|im_start|>assistant" if self.assistant_prefix_hint: prompt += f"\n{FIRST_PREFIX_HINT if first_message else PREFIX_HINT}" if self.supports_first_message and first_message: @@ -386,7 +385,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper): "You must always include inner thoughts, but you do not always have to call a function.", ] ) - prompt += f"\nAvailable functions:" + prompt += "\nAvailable functions:" for function_dict in functions: prompt += f"\n{self._compile_function_description(function_dict, add_inner_thoughts=False)}" diff --git 
a/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py b/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py index fa0fa839..9f53fa83 100644 --- a/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +++ b/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py @@ -91,9 +91,9 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper): func_str = "" func_str += f"{schema['name']}:" func_str += f"\n description: {schema['description']}" - func_str += f"\n params:" + func_str += "\n params:" if add_inner_thoughts: - func_str += f"\n inner_thoughts: Deep inner monologue private to you only." + func_str += "\n inner_thoughts: Deep inner monologue private to you only." for param_k, param_v in schema["parameters"]["properties"].items(): # TODO we're ignoring type func_str += f"\n {param_k}: {param_v['description']}" @@ -105,8 +105,8 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper): prompt = "" # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." - prompt += f"Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." - prompt += f"\nAvailable functions:" + prompt += "Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." 
+ prompt += "\nAvailable functions:" for function_dict in functions: prompt += f"\n{self._compile_function_description(function_dict)}" @@ -117,8 +117,8 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper): prompt = system_message prompt += "\n" if function_documentation is not None: - prompt += f"Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." - prompt += f"\nAvailable functions:" + prompt += "Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." + prompt += "\nAvailable functions:" prompt += function_documentation else: prompt += self._compile_function_block(functions) diff --git a/letta/local_llm/llm_chat_completion_wrappers/dolphin.py b/letta/local_llm/llm_chat_completion_wrappers/dolphin.py index 6a7f0852..e393d9b1 100644 --- a/letta/local_llm/llm_chat_completion_wrappers/dolphin.py +++ b/letta/local_llm/llm_chat_completion_wrappers/dolphin.py @@ -85,7 +85,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): func_str = "" func_str += f"{schema['name']}:" func_str += f"\n description: {schema['description']}" - func_str += f"\n params:" + func_str += "\n params:" for param_k, param_v in schema["parameters"]["properties"].items(): # TODO we're ignoring type func_str += f"\n {param_k}: {param_v['description']}" @@ -93,8 +93,8 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): return func_str # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." - prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." 
- prompt += f"\nAvailable functions:" + prompt += "\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." + prompt += "\nAvailable functions:" if function_documentation is not None: prompt += f"\n{function_documentation}" else: diff --git a/letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py b/letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py index c69f0960..d20bd2d3 100644 --- a/letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +++ b/letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py @@ -124,7 +124,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper): if self.include_assistant_prefix: # prompt += f"\nASSISTANT:" - prompt += f"\nSUMMARY:" + prompt += "\nSUMMARY:" # print(prompt) return prompt diff --git a/letta/local_llm/ollama/api.py b/letta/local_llm/ollama/api.py index 00bdf509..69926a43 100644 --- a/letta/local_llm/ollama/api.py +++ b/letta/local_llm/ollama/api.py @@ -18,7 +18,7 @@ def get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_ if model is None: raise LocalLLMError( - f"Error: model name not specified. Set model in your config to the model you want to run (e.g. 'dolphin2.2-mistral')" + "Error: model name not specified. Set model in your config to the model you want to run (e.g. 
'dolphin2.2-mistral')" ) # Settings for the generation, includes the prompt + stop tokens, max length, etc @@ -51,7 +51,7 @@ def get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_ # Set grammar if grammar is not None: # request["grammar_string"] = load_grammar_file(grammar) - raise NotImplementedError(f"Ollama does not support grammars") + raise NotImplementedError("Ollama does not support grammars") if not endpoint.startswith(("http://", "https://")): raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://") diff --git a/letta/schemas/embedding_config.py b/letta/schemas/embedding_config.py index 38e6542d..1020382b 100644 --- a/letta/schemas/embedding_config.py +++ b/letta/schemas/embedding_config.py @@ -2,6 +2,8 @@ from typing import Literal, Optional from pydantic import BaseModel, Field +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE + class EmbeddingConfig(BaseModel): """ @@ -62,7 +64,7 @@ class EmbeddingConfig(BaseModel): embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", embedding_dim=1536, - embedding_chunk_size=300, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, ) if (model_name == "text-embedding-3-small" and provider == "openai") or (not model_name and provider == "openai"): return cls( @@ -70,14 +72,14 @@ class EmbeddingConfig(BaseModel): embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", embedding_dim=2000, - embedding_chunk_size=300, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, ) elif model_name == "letta": return cls( embedding_endpoint="https://embeddings.memgpt.ai", embedding_model="BAAI/bge-large-en-v1.5", embedding_dim=1024, - embedding_chunk_size=300, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, embedding_endpoint_type="hugging-face", ) else: diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index cbe922ca..a76178f9 100644 --- a/letta/schemas/enums.py +++ 
b/letta/schemas/enums.py @@ -8,6 +8,7 @@ class ProviderType(str, Enum): openai = "openai" letta = "letta" deepseek = "deepseek" + cerebras = "cerebras" lmstudio_openai = "lmstudio_openai" xai = "xai" mistral = "mistral" @@ -17,6 +18,7 @@ class ProviderType(str, Enum): azure = "azure" vllm = "vllm" bedrock = "bedrock" + cohere = "cohere" class ProviderCategory(str, Enum): diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py deleted file mode 100644 index 97d68281..00000000 --- a/letta/schemas/providers.py +++ /dev/null @@ -1,1618 +0,0 @@ -import warnings -from datetime import datetime -from typing import List, Literal, Optional - -import aiohttp -import requests -from pydantic import BaseModel, Field, model_validator - -from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW -from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint -from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES -from letta.schemas.enums import ProviderCategory, ProviderType -from letta.schemas.letta_base import LettaBase -from letta.schemas.llm_config import LLMConfig -from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES -from letta.settings import model_settings - - -class ProviderBase(LettaBase): - __id_prefix__ = "provider" - - -class Provider(ProviderBase): - id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.") - name: str = Field(..., description="The name of the provider") - provider_type: ProviderType = Field(..., description="The type of the provider") - provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)") - api_key: Optional[str] = Field(None, description="API key or 
secret key used for requests to the provider.") - base_url: Optional[str] = Field(None, description="Base URL for the provider.") - access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") - region: Optional[str] = Field(None, description="Region used for requests to the provider.") - organization_id: Optional[str] = Field(None, description="The organization id of the user") - updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.") - - @model_validator(mode="after") - def default_base_url(self): - if self.provider_type == ProviderType.openai and self.base_url is None: - self.base_url = model_settings.openai_api_base - return self - - def resolve_identifier(self): - if not self.id: - self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) - - def check_api_key(self): - """Check if the API key is valid for the provider""" - raise NotImplementedError - - def list_llm_models(self) -> List[LLMConfig]: - return [] - - async def list_llm_models_async(self) -> List[LLMConfig]: - return [] - - def list_embedding_models(self) -> List[EmbeddingConfig]: - return [] - - async def list_embedding_models_async(self) -> List[EmbeddingConfig]: - return self.list_embedding_models() - - def get_model_context_window(self, model_name: str) -> Optional[int]: - raise NotImplementedError - - async def get_model_context_window_async(self, model_name: str) -> Optional[int]: - raise NotImplementedError - - def provider_tag(self) -> str: - """String representation of the provider for display purposes""" - raise NotImplementedError - - def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str: - """ - Get the handle for a model, with support for custom overrides. - - Args: - model_name (str): The name of the model. - is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False. 
- - Returns: - str: The handle for the model. - """ - base_name = base_name if base_name else self.name - - overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES - if base_name in overrides and model_name in overrides[base_name]: - model_name = overrides[base_name][model_name] - - return f"{base_name}/{model_name}" - - def cast_to_subtype(self): - match self.provider_type: - case ProviderType.letta: - return LettaProvider(**self.model_dump(exclude_none=True)) - case ProviderType.openai: - return OpenAIProvider(**self.model_dump(exclude_none=True)) - case ProviderType.anthropic: - return AnthropicProvider(**self.model_dump(exclude_none=True)) - case ProviderType.bedrock: - return BedrockProvider(**self.model_dump(exclude_none=True)) - case ProviderType.ollama: - return OllamaProvider(**self.model_dump(exclude_none=True)) - case ProviderType.google_ai: - return GoogleAIProvider(**self.model_dump(exclude_none=True)) - case ProviderType.google_vertex: - return GoogleVertexProvider(**self.model_dump(exclude_none=True)) - case ProviderType.azure: - return AzureProvider(**self.model_dump(exclude_none=True)) - case ProviderType.groq: - return GroqProvider(**self.model_dump(exclude_none=True)) - case ProviderType.together: - return TogetherProvider(**self.model_dump(exclude_none=True)) - case ProviderType.vllm_chat_completions: - return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True)) - case ProviderType.vllm_completions: - return VLLMCompletionsProvider(**self.model_dump(exclude_none=True)) - case ProviderType.xai: - return XAIProvider(**self.model_dump(exclude_none=True)) - case _: - raise ValueError(f"Unknown provider type: {self.provider_type}") - - -class ProviderCreate(ProviderBase): - name: str = Field(..., description="The name of the provider.") - provider_type: ProviderType = Field(..., description="The type of the provider.") - api_key: str = Field(..., description="API key or secret key used for requests to the 
provider.") - access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") - region: Optional[str] = Field(None, description="Region used for requests to the provider.") - - -class ProviderUpdate(ProviderBase): - api_key: str = Field(..., description="API key or secret key used for requests to the provider.") - access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") - region: Optional[str] = Field(None, description="Region used for requests to the provider.") - - -class ProviderCheck(BaseModel): - provider_type: ProviderType = Field(..., description="The type of the provider.") - api_key: str = Field(..., description="API key or secret key used for requests to the provider.") - access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") - region: Optional[str] = Field(None, description="Region used for requests to the provider.") - - -class LettaProvider(Provider): - provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - - def list_llm_models(self) -> List[LLMConfig]: - return [ - LLMConfig( - model="letta-free", # NOTE: renamed - model_endpoint_type="openai", - model_endpoint=LETTA_MODEL_ENDPOINT, - context_window=30000, - handle=self.get_handle("letta-free"), - provider_name=self.name, - provider_category=self.provider_category, - ) - ] - - async def list_llm_models_async(self) -> List[LLMConfig]: - return [ - LLMConfig( - model="letta-free", # NOTE: renamed - model_endpoint_type="openai", - model_endpoint=LETTA_MODEL_ENDPOINT, - context_window=30000, - handle=self.get_handle("letta-free"), - provider_name=self.name, - provider_category=self.provider_category, - ) - ] - - def list_embedding_models(self): - return [ - EmbeddingConfig( - 
embedding_model="letta-free", # NOTE: renamed - embedding_endpoint_type="hugging-face", - embedding_endpoint="https://embeddings.memgpt.ai", - embedding_dim=1024, - embedding_chunk_size=300, - handle=self.get_handle("letta-free", is_embedding=True), - batch_size=32, - ) - ] - - -class OpenAIProvider(Provider): - provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the OpenAI API.") - base_url: str = Field(..., description="Base URL for the OpenAI API.") - - def check_api_key(self): - from letta.llm_api.openai import openai_check_valid_api_key - - openai_check_valid_api_key(self.base_url, self.api_key) - - def _get_models(self) -> List[dict]: - from letta.llm_api.openai import openai_get_model_list - - # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... - # See: https://openrouter.ai/docs/requests - extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None - - # Similar to Nebius - extra_params = {"verbose": True} if "nebius.com" in self.base_url else None - - response = openai_get_model_list( - self.base_url, - api_key=self.api_key, - extra_params=extra_params, - # fix_url=True, # NOTE: make sure together ends with /v1 - ) - - if "data" in response: - data = response["data"] - else: - # TogetherAI's response is missing the 'data' field - data = response - - return data - - async def _get_models_async(self) -> List[dict]: - from letta.llm_api.openai import openai_get_model_list_async - - # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... 
- # See: https://openrouter.ai/docs/requests - extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None - - # Similar to Nebius - extra_params = {"verbose": True} if "nebius.com" in self.base_url else None - - response = await openai_get_model_list_async( - self.base_url, - api_key=self.api_key, - extra_params=extra_params, - # fix_url=True, # NOTE: make sure together ends with /v1 - ) - - if "data" in response: - data = response["data"] - else: - # TogetherAI's response is missing the 'data' field - data = response - - return data - - def list_llm_models(self) -> List[LLMConfig]: - data = self._get_models() - return self._list_llm_models(data) - - async def list_llm_models_async(self) -> List[LLMConfig]: - data = await self._get_models_async() - return self._list_llm_models(data) - - def _list_llm_models(self, data) -> List[LLMConfig]: - configs = [] - for model in data: - assert "id" in model, f"OpenAI model missing 'id' field: {model}" - model_name = model["id"] - - if "context_length" in model: - # Context length is returned in OpenRouter as "context_length" - context_window_size = model["context_length"] - else: - context_window_size = self.get_model_context_window_size(model_name) - - if not context_window_size: - continue - - # TogetherAI includes the type, which we can use to filter out embedding models - if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url: - if "type" in model and model["type"] not in ["chat", "language"]: - continue - - # for TogetherAI, we need to skip the models that don't support JSON mode / function calling - # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: { - # "error": { - # "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling", - # "type": "invalid_request_error", - # "param": null, - # "code": "constraints_model" - # } 
- # } - if "config" not in model: - continue - - if "nebius.com" in self.base_url: - # Nebius includes the type, which we can use to filter for text models - try: - model_type = model["architecture"]["modality"] - if model_type not in ["text->text", "text+image->text"]: - # print(f"Skipping model w/ modality {model_type}:\n{model}") - continue - except KeyError: - print(f"Couldn't access architecture type field, skipping model:\n{model}") - continue - - # for openai, filter models - if self.base_url == "https://api.openai.com/v1": - allowed_types = ["gpt-4", "o1", "o3", "o4"] - # NOTE: o1-mini and o1-preview do not support tool calling - # NOTE: o1-mini does not support system messages - # NOTE: o1-pro is only available in Responses API - disallowed_types = ["transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"] - skip = True - for model_type in allowed_types: - if model_name.startswith(model_type): - skip = False - break - for keyword in disallowed_types: - if keyword in model_name: - skip = True - break - # ignore this model - if skip: - continue - - # set the handle to openai-proxy if the base URL isn't OpenAI - if self.base_url != "https://api.openai.com/v1": - handle = self.get_handle(model_name, base_name="openai-proxy") - else: - handle = self.get_handle(model_name) - - llm_config = LLMConfig( - model=model_name, - model_endpoint_type="openai", - model_endpoint=self.base_url, - context_window=context_window_size, - handle=handle, - provider_name=self.name, - provider_category=self.provider_category, - ) - - # gpt-4o-mini has started to regress with pretty bad emoji spam loops - # this is to counteract that - if "gpt-4o-mini" in model_name: - llm_config.frequency_penalty = 1.0 - if "gpt-4.1-mini" in model_name: - llm_config.frequency_penalty = 1.0 - - configs.append(llm_config) - - # for OpenAI, sort in reverse order - if self.base_url == "https://api.openai.com/v1": - # alphnumeric sort - configs.sort(key=lambda x: 
x.model, reverse=True) - - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - if self.base_url == "https://api.openai.com/v1": - # TODO: actually automatically list models for OpenAI - return [ - EmbeddingConfig( - embedding_model="text-embedding-ada-002", - embedding_endpoint_type="openai", - embedding_endpoint=self.base_url, - embedding_dim=1536, - embedding_chunk_size=300, - handle=self.get_handle("text-embedding-ada-002", is_embedding=True), - batch_size=1024, - ), - EmbeddingConfig( - embedding_model="text-embedding-3-small", - embedding_endpoint_type="openai", - embedding_endpoint=self.base_url, - embedding_dim=2000, - embedding_chunk_size=300, - handle=self.get_handle("text-embedding-3-small", is_embedding=True), - batch_size=1024, - ), - EmbeddingConfig( - embedding_model="text-embedding-3-large", - embedding_endpoint_type="openai", - embedding_endpoint=self.base_url, - embedding_dim=2000, - embedding_chunk_size=300, - handle=self.get_handle("text-embedding-3-large", is_embedding=True), - batch_size=1024, - ), - ] - - else: - # Actually attempt to list - data = self._get_models() - return self._list_embedding_models(data) - - async def list_embedding_models_async(self) -> List[EmbeddingConfig]: - if self.base_url == "https://api.openai.com/v1": - # TODO: actually automatically list models for OpenAI - return [ - EmbeddingConfig( - embedding_model="text-embedding-ada-002", - embedding_endpoint_type="openai", - embedding_endpoint=self.base_url, - embedding_dim=1536, - embedding_chunk_size=300, - handle=self.get_handle("text-embedding-ada-002", is_embedding=True), - batch_size=1024, - ), - EmbeddingConfig( - embedding_model="text-embedding-3-small", - embedding_endpoint_type="openai", - embedding_endpoint=self.base_url, - embedding_dim=2000, - embedding_chunk_size=300, - handle=self.get_handle("text-embedding-3-small", is_embedding=True), - batch_size=1024, - ), - EmbeddingConfig( - embedding_model="text-embedding-3-large", - 
embedding_endpoint_type="openai", - embedding_endpoint=self.base_url, - embedding_dim=2000, - embedding_chunk_size=300, - handle=self.get_handle("text-embedding-3-large", is_embedding=True), - batch_size=1024, - ), - ] - - else: - # Actually attempt to list - data = await self._get_models_async() - return self._list_embedding_models(data) - - def _list_embedding_models(self, data) -> List[EmbeddingConfig]: - configs = [] - for model in data: - assert "id" in model, f"Model missing 'id' field: {model}" - model_name = model["id"] - - if "context_length" in model: - # Context length is returned in Nebius as "context_length" - context_window_size = model["context_length"] - else: - context_window_size = self.get_model_context_window_size(model_name) - - # We need the context length for embeddings too - if not context_window_size: - continue - - if "nebius.com" in self.base_url: - # Nebius includes the type, which we can use to filter for embedidng models - try: - model_type = model["architecture"]["modality"] - if model_type not in ["text->embedding"]: - # print(f"Skipping model w/ modality {model_type}:\n{model}") - continue - except KeyError: - print(f"Couldn't access architecture type field, skipping model:\n{model}") - continue - - elif "together.ai" in self.base_url or "together.xyz" in self.base_url: - # TogetherAI includes the type, which we can use to filter for embedding models - if "type" in model and model["type"] not in ["embedding"]: - # print(f"Skipping model w/ modality {model_type}:\n{model}") - continue - - else: - # For other providers we should skip by default, since we don't want to assume embeddings are supported - continue - - configs.append( - EmbeddingConfig( - embedding_model=model_name, - embedding_endpoint_type=self.provider_type, - embedding_endpoint=self.base_url, - embedding_dim=context_window_size, - embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, - handle=self.get_handle(model, is_embedding=True), - ) - ) - - return configs - - def 
get_model_context_window_size(self, model_name: str): - if model_name in LLM_MAX_TOKENS: - return LLM_MAX_TOKENS[model_name] - else: - return LLM_MAX_TOKENS["DEFAULT"] - - -class DeepSeekProvider(OpenAIProvider): - """ - DeepSeek ChatCompletions API is similar to OpenAI's reasoning API, - but with slight differences: - * For example, DeepSeek's API requires perfect interleaving of user/assistant - * It also does not support native function calling - """ - - provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") - api_key: str = Field(..., description="API key for the DeepSeek API.") - - def get_model_context_window_size(self, model_name: str) -> Optional[int]: - # DeepSeek doesn't return context window in the model listing, - # so these are hardcoded from their website - if model_name == "deepseek-reasoner": - return 64000 - elif model_name == "deepseek-chat": - return 64000 - else: - return None - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.openai import openai_get_model_list - - response = openai_get_model_list(self.base_url, api_key=self.api_key) - - if "data" in response: - data = response["data"] - else: - data = response - - configs = [] - for model in data: - assert "id" in model, f"DeepSeek model missing 'id' field: {model}" - model_name = model["id"] - - # In case DeepSeek starts supporting it in the future: - if "context_length" in model: - # Context length is returned in OpenRouter as "context_length" - context_window_size = model["context_length"] - else: - context_window_size = self.get_model_context_window_size(model_name) - - if not context_window_size: - warnings.warn(f"Couldn't find context window size for model {model_name}") - continue 
- - # Not used for deepseek-reasoner, but otherwise is true - put_inner_thoughts_in_kwargs = False if model_name == "deepseek-reasoner" else True - - configs.append( - LLMConfig( - model=model_name, - model_endpoint_type="deepseek", - model_endpoint=self.base_url, - context_window=context_window_size, - handle=self.get_handle(model_name), - put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - # No embeddings supported - return [] - - -class LMStudioOpenAIProvider(OpenAIProvider): - provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.") - api_key: Optional[str] = Field(None, description="API key for the LMStudio API.") - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.openai import openai_get_model_list - - # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models' - MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0" - response = openai_get_model_list(MODEL_ENDPOINT_URL) - - """ - Example response: - - { - "object": "list", - "data": [ - { - "id": "qwen2-vl-7b-instruct", - "object": "model", - "type": "vlm", - "publisher": "mlx-community", - "arch": "qwen2_vl", - "compatibility_type": "mlx", - "quantization": "4bit", - "state": "not-loaded", - "max_context_length": 32768 - }, - ... 
- """ - if "data" not in response: - warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}") - return [] - - configs = [] - for model in response["data"]: - assert "id" in model, f"Model missing 'id' field: {model}" - model_name = model["id"] - - if "type" not in model: - warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}") - continue - elif model["type"] not in ["vlm", "llm"]: - continue - - if "max_context_length" in model: - context_window_size = model["max_context_length"] - else: - warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}") - continue - - configs.append( - LLMConfig( - model=model_name, - model_endpoint_type="openai", - model_endpoint=self.base_url, - context_window=context_window_size, - handle=self.get_handle(model_name), - ) - ) - - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - from letta.llm_api.openai import openai_get_model_list - - # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models' - MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0" - response = openai_get_model_list(MODEL_ENDPOINT_URL) - - """ - Example response: - { - "object": "list", - "data": [ - { - "id": "text-embedding-nomic-embed-text-v1.5", - "object": "model", - "type": "embeddings", - "publisher": "nomic-ai", - "arch": "nomic-bert", - "compatibility_type": "gguf", - "quantization": "Q4_0", - "state": "not-loaded", - "max_context_length": 2048 - } - ... 
- """ - if "data" not in response: - warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}") - return [] - - configs = [] - for model in response["data"]: - assert "id" in model, f"Model missing 'id' field: {model}" - model_name = model["id"] - - if "type" not in model: - warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}") - continue - elif model["type"] not in ["embeddings"]: - continue - - if "max_context_length" in model: - context_window_size = model["max_context_length"] - else: - warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}") - continue - - configs.append( - EmbeddingConfig( - embedding_model=model_name, - embedding_endpoint_type="openai", - embedding_endpoint=self.base_url, - embedding_dim=context_window_size, - embedding_chunk_size=300, # NOTE: max is 2048 - handle=self.get_handle(model_name), - ), - ) - - return configs - - -class XAIProvider(OpenAIProvider): - """https://docs.x.ai/docs/api-reference""" - - provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the xAI/Grok API.") - base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") - - def get_model_context_window_size(self, model_name: str) -> Optional[int]: - # xAI doesn't return context window in the model listing, - # so these are hardcoded from their website - if model_name == "grok-2-1212": - return 131072 - # NOTE: disabling the minis for now since they return weird MM parts - # elif model_name == "grok-3-mini-fast-beta": - # return 131072 - # elif model_name == "grok-3-mini-beta": - # return 131072 - elif model_name == "grok-3-fast-beta": - return 131072 - elif model_name == "grok-3-beta": - return 131072 - else: - return None - - def 
list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.openai import openai_get_model_list - - response = openai_get_model_list(self.base_url, api_key=self.api_key) - - if "data" in response: - data = response["data"] - else: - data = response - - configs = [] - for model in data: - assert "id" in model, f"xAI/Grok model missing 'id' field: {model}" - model_name = model["id"] - - # In case xAI starts supporting it in the future: - if "context_length" in model: - context_window_size = model["context_length"] - else: - context_window_size = self.get_model_context_window_size(model_name) - - if not context_window_size: - warnings.warn(f"Couldn't find context window size for model {model_name}") - continue - - configs.append( - LLMConfig( - model=model_name, - model_endpoint_type="xai", - model_endpoint=self.base_url, - context_window=context_window_size, - handle=self.get_handle(model_name), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - # No embeddings supported - return [] - - -class AnthropicProvider(Provider): - provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the Anthropic API.") - base_url: str = "https://api.anthropic.com/v1" - - def check_api_key(self): - from letta.llm_api.anthropic import anthropic_check_valid_api_key - - anthropic_check_valid_api_key(self.api_key) - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.anthropic import anthropic_get_model_list - - models = anthropic_get_model_list(api_key=self.api_key) - return self._list_llm_models(models) - - async def list_llm_models_async(self) -> List[LLMConfig]: - from letta.llm_api.anthropic import 
anthropic_get_model_list_async - - models = await anthropic_get_model_list_async(api_key=self.api_key) - return self._list_llm_models(models) - - def _list_llm_models(self, models) -> List[LLMConfig]: - from letta.llm_api.anthropic import MODEL_LIST - - configs = [] - for model in models: - - if model["type"] != "model": - continue - - if "id" not in model: - continue - - # Don't support 2.0 and 2.1 - if model["id"].startswith("claude-2"): - continue - - # Anthropic doesn't return the context window in their API - if "context_window" not in model: - # Remap list to name: context_window - model_library = {m["name"]: m["context_window"] for m in MODEL_LIST} - # Attempt to look it up in a hardcoded list - if model["id"] in model_library: - model["context_window"] = model_library[model["id"]] - else: - # On fallback, we can set 200k (generally safe), but we should warn the user - warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000") - model["context_window"] = 200000 - - max_tokens = 8192 - if "claude-3-opus" in model["id"]: - max_tokens = 4096 - if "claude-3-haiku" in model["id"]: - max_tokens = 4096 - # TODO: set for 3-7 extended thinking mode - - # We set this to false by default, because Anthropic can - # natively support tags inside of content fields - # However, putting COT inside of tool calls can make it more - # reliable for tool calling (no chance of a non-tool call step) - # Since tool_choice_type 'any' doesn't work with in-content COT - # NOTE For Haiku, it can be flaky if we don't enable this by default - # inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False - inner_thoughts_in_kwargs = True # we no longer support thinking tags - - configs.append( - LLMConfig( - model=model["id"], - model_endpoint_type="anthropic", - model_endpoint=self.base_url, - context_window=model["context_window"], - handle=self.get_handle(model["id"]), - put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, - 
max_tokens=max_tokens, - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - return configs - - -class MistralProvider(Provider): - provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the Mistral API.") - base_url: str = "https://api.mistral.ai/v1" - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.mistral import mistral_get_model_list - - # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... - # See: https://openrouter.ai/docs/requests - response = mistral_get_model_list(self.base_url, api_key=self.api_key) - - assert "data" in response, f"Mistral model query response missing 'data' field: {response}" - - configs = [] - for model in response["data"]: - # If model has chat completions and function calling enabled - if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]: - configs.append( - LLMConfig( - model=model["id"], - model_endpoint_type="openai", - model_endpoint=self.base_url, - context_window=model["max_context_length"], - handle=self.get_handle(model["id"]), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - # Not supported for mistral - return [] - - def get_model_context_window(self, model_name: str) -> Optional[int]: - # Redoing this is fine because it's a pretty lightweight call - models = self.list_llm_models() - - for m in models: - if model_name in m["id"]: - return int(m["max_context_length"]) - - return None - - -class OllamaProvider(OpenAIProvider): - """Ollama provider that uses the native /api/generate endpoint - - See: 
https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion - """ - - provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - base_url: str = Field(..., description="Base URL for the Ollama API.") - api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).") - default_prompt_formatter: str = Field( - ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API." - ) - - async def list_llm_models_async(self) -> List[LLMConfig]: - """Async version of list_llm_models below""" - endpoint = f"{self.base_url}/api/tags" - async with aiohttp.ClientSession() as session: - async with session.get(endpoint) as response: - if response.status != 200: - raise Exception(f"Failed to list Ollama models: {response.text}") - response_json = await response.json() - - configs = [] - for model in response_json["models"]: - context_window = self.get_model_context_window(model["name"]) - if context_window is None: - print(f"Ollama model {model['name']} has no context window") - continue - configs.append( - LLMConfig( - model=model["name"], - model_endpoint_type="ollama", - model_endpoint=self.base_url, - model_wrapper=self.default_prompt_formatter, - context_window=context_window, - handle=self.get_handle(model["name"]), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - return configs - - def list_llm_models(self) -> List[LLMConfig]: - # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models - response = requests.get(f"{self.base_url}/api/tags") - if response.status_code != 200: - raise Exception(f"Failed to list Ollama models: {response.text}") - response_json = response.json() - - configs = [] - for model in response_json["models"]: - context_window = 
self.get_model_context_window(model["name"]) - if context_window is None: - print(f"Ollama model {model['name']} has no context window") - continue - configs.append( - LLMConfig( - model=model["name"], - model_endpoint_type="ollama", - model_endpoint=self.base_url, - model_wrapper=self.default_prompt_formatter, - context_window=context_window, - handle=self.get_handle(model["name"]), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - return configs - - def get_model_context_window(self, model_name: str) -> Optional[int]: - response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) - response_json = response.json() - - ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675 - # possible_keys = [ - # # OPT - # "max_position_embeddings", - # # GPT-2 - # "n_positions", - # # MPT - # "max_seq_len", - # # ChatGLM2 - # "seq_length", - # # Command-R - # "model_max_length", - # # Others - # "max_sequence_length", - # "max_seq_length", - # "seq_len", - # ] - # max_position_embeddings - # parse model cards: nous, dolphon, llama - if "model_info" not in response_json: - if "error" in response_json: - print(f"Ollama fetch model info error for {model_name}: {response_json['error']}") - return None - for key, value in response_json["model_info"].items(): - if "context_length" in key: - return value - return None - - def _get_model_embedding_dim(self, model_name: str): - response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) - response_json = response.json() - return self._get_model_embedding_dim_impl(response_json, model_name) - - async def _get_model_embedding_dim_async(self, model_name: str): - async with aiohttp.ClientSession() as session: - async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response: - response_json = await response.json() - return 
self._get_model_embedding_dim_impl(response_json, model_name) - - @staticmethod - def _get_model_embedding_dim_impl(response_json: dict, model_name: str): - if "model_info" not in response_json: - if "error" in response_json: - print(f"Ollama fetch model info error for {model_name}: {response_json['error']}") - return None - for key, value in response_json["model_info"].items(): - if "embedding_length" in key: - return value - return None - - async def list_embedding_models_async(self) -> List[EmbeddingConfig]: - """Async version of list_embedding_models below""" - endpoint = f"{self.base_url}/api/tags" - async with aiohttp.ClientSession() as session: - async with session.get(endpoint) as response: - if response.status != 200: - raise Exception(f"Failed to list Ollama models: {response.text}") - response_json = await response.json() - - configs = [] - for model in response_json["models"]: - embedding_dim = await self._get_model_embedding_dim_async(model["name"]) - if not embedding_dim: - print(f"Ollama model {model['name']} has no embedding dimension") - continue - configs.append( - EmbeddingConfig( - embedding_model=model["name"], - embedding_endpoint_type="ollama", - embedding_endpoint=self.base_url, - embedding_dim=embedding_dim, - embedding_chunk_size=300, - handle=self.get_handle(model["name"], is_embedding=True), - ) - ) - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models - response = requests.get(f"{self.base_url}/api/tags") - if response.status_code != 200: - raise Exception(f"Failed to list Ollama models: {response.text}") - response_json = response.json() - - configs = [] - for model in response_json["models"]: - embedding_dim = self._get_model_embedding_dim(model["name"]) - if not embedding_dim: - print(f"Ollama model {model['name']} has no embedding dimension") - continue - configs.append( - EmbeddingConfig( - embedding_model=model["name"], - 
embedding_endpoint_type="ollama", - embedding_endpoint=self.base_url, - embedding_dim=embedding_dim, - embedding_chunk_size=300, - handle=self.get_handle(model["name"], is_embedding=True), - ) - ) - return configs - - -class GroqProvider(OpenAIProvider): - provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - base_url: str = "https://api.groq.com/openai/v1" - api_key: str = Field(..., description="API key for the Groq API.") - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.openai import openai_get_model_list - - response = openai_get_model_list(self.base_url, api_key=self.api_key) - configs = [] - for model in response["data"]: - if not "context_window" in model: - continue - configs.append( - LLMConfig( - model=model["id"], - model_endpoint_type="groq", - model_endpoint=self.base_url, - context_window=model["context_window"], - handle=self.get_handle(model["id"]), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - return [] - - -class TogetherProvider(OpenAIProvider): - """TogetherAI provider that uses the /completions API - - TogetherAI can also be used via the /chat/completions API - by settings OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key - and API URL, however /completions is preferred because their /chat/completions - function calling support is limited. 
- """ - - provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - base_url: str = "https://api.together.ai/v1" - api_key: str = Field(..., description="API key for the TogetherAI API.") - default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.openai import openai_get_model_list - - models = openai_get_model_list(self.base_url, api_key=self.api_key) - return self._list_llm_models(models) - - async def list_llm_models_async(self) -> List[LLMConfig]: - from letta.llm_api.openai import openai_get_model_list_async - - models = await openai_get_model_list_async(self.base_url, api_key=self.api_key) - return self._list_llm_models(models) - - def _list_llm_models(self, models) -> List[LLMConfig]: - pass - - # TogetherAI's response is missing the 'data' field - # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" - if "data" in models: - data = models["data"] - else: - data = models - - configs = [] - for model in data: - assert "id" in model, f"TogetherAI model missing 'id' field: {model}" - model_name = model["id"] - - if "context_length" in model: - # Context length is returned in OpenRouter as "context_length" - context_window_size = model["context_length"] - else: - context_window_size = self.get_model_context_window_size(model_name) - - # We need the context length for embeddings too - if not context_window_size: - continue - - # Skip models that are too small for Letta - if context_window_size <= MIN_CONTEXT_WINDOW: - continue - - # TogetherAI includes the type, which we can use to filter for embedding models - if "type" in model and model["type"] not in ["chat", "language"]: - continue - - 
configs.append( - LLMConfig( - model=model_name, - model_endpoint_type="together", - model_endpoint=self.base_url, - model_wrapper=self.default_prompt_formatter, - context_window=context_window_size, - handle=self.get_handle(model_name), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - # TODO renable once we figure out how to pass API keys through properly - return [] - - # from letta.llm_api.openai import openai_get_model_list - - # response = openai_get_model_list(self.base_url, api_key=self.api_key) - - # # TogetherAI's response is missing the 'data' field - # # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" - # if "data" in response: - # data = response["data"] - # else: - # data = response - - # configs = [] - # for model in data: - # assert "id" in model, f"TogetherAI model missing 'id' field: {model}" - # model_name = model["id"] - - # if "context_length" in model: - # # Context length is returned in OpenRouter as "context_length" - # context_window_size = model["context_length"] - # else: - # context_window_size = self.get_model_context_window_size(model_name) - - # if not context_window_size: - # continue - - # # TogetherAI includes the type, which we can use to filter out embedding models - # if "type" in model and model["type"] not in ["embedding"]: - # continue - - # configs.append( - # EmbeddingConfig( - # embedding_model=model_name, - # embedding_endpoint_type="openai", - # embedding_endpoint=self.base_url, - # embedding_dim=context_window_size, - # embedding_chunk_size=300, # TODO: change? 
- # ) - # ) - - # return configs - - -class GoogleAIProvider(Provider): - # gemini - provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the Google AI API.") - base_url: str = "https://generativelanguage.googleapis.com" - - def check_api_key(self): - from letta.llm_api.google_ai_client import google_ai_check_valid_api_key - - google_ai_check_valid_api_key(self.api_key) - - def list_llm_models(self): - from letta.llm_api.google_ai_client import google_ai_get_model_list - - model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) - model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] - model_options = [str(m["name"]) for m in model_options] - - # filter by model names - model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] - - # Add support for all gemini models - model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] - - configs = [] - for model in model_options: - configs.append( - LLMConfig( - model=model, - model_endpoint_type="google_ai", - model_endpoint=self.base_url, - context_window=self.get_model_context_window(model), - handle=self.get_handle(model), - max_tokens=8192, - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - - return configs - - async def list_llm_models_async(self): - import asyncio - - from letta.llm_api.google_ai_client import google_ai_get_model_list_async - - # Get and filter the model list - model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) - model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] - model_options = 
[str(m["name"]) for m in model_options] - - # filter by model names - model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] - - # Add support for all gemini models - model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] - - # Prepare tasks for context window lookups in parallel - async def create_config(model): - context_window = await self.get_model_context_window_async(model) - return LLMConfig( - model=model, - model_endpoint_type="google_ai", - model_endpoint=self.base_url, - context_window=context_window, - handle=self.get_handle(model), - max_tokens=8192, - provider_name=self.name, - provider_category=self.provider_category, - ) - - # Execute all config creation tasks concurrently - configs = await asyncio.gather(*[create_config(model) for model in model_options]) - - return configs - - def list_embedding_models(self): - from letta.llm_api.google_ai_client import google_ai_get_model_list - - # TODO: use base_url instead - model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) - return self._list_embedding_models(model_options) - - async def list_embedding_models_async(self): - from letta.llm_api.google_ai_client import google_ai_get_model_list_async - - # TODO: use base_url instead - model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) - return self._list_embedding_models(model_options) - - def _list_embedding_models(self, model_options): - # filter by 'generateContent' models - model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]] - model_options = [str(m["name"]) for m in model_options] - model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] - - configs = [] - for model in model_options: - configs.append( - EmbeddingConfig( - embedding_model=model, - embedding_endpoint_type="google_ai", - embedding_endpoint=self.base_url, - 
embedding_dim=768, - embedding_chunk_size=300, # NOTE: max is 2048 - handle=self.get_handle(model, is_embedding=True), - batch_size=1024, - ) - ) - return configs - - def get_model_context_window(self, model_name: str) -> Optional[int]: - from letta.llm_api.google_ai_client import google_ai_get_model_context_window - - if model_name in LLM_MAX_TOKENS: - return LLM_MAX_TOKENS[model_name] - else: - return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) - - async def get_model_context_window_async(self, model_name: str) -> Optional[int]: - from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async - - if model_name in LLM_MAX_TOKENS: - return LLM_MAX_TOKENS[model_name] - else: - return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name) - - -class GoogleVertexProvider(Provider): - provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.") - google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.") - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH - - configs = [] - for model, context_length in GOOGLE_MODEL_TO_CONTEXT_LENGTH.items(): - configs.append( - LLMConfig( - model=model, - model_endpoint_type="google_vertex", - model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}", - context_window=context_length, - handle=self.get_handle(model), - max_tokens=8192, - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - return configs - - def 
list_embedding_models(self) -> List[EmbeddingConfig]: - from letta.llm_api.google_constants import GOOGLE_EMBEDING_MODEL_TO_DIM - - configs = [] - for model, dim in GOOGLE_EMBEDING_MODEL_TO_DIM.items(): - configs.append( - EmbeddingConfig( - embedding_model=model, - embedding_endpoint_type="google_vertex", - embedding_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}", - embedding_dim=dim, - embedding_chunk_size=300, # NOTE: max is 2048 - handle=self.get_handle(model, is_embedding=True), - batch_size=1024, - ) - ) - return configs - - -class AzureProvider(Provider): - provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation - base_url: str = Field( - ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." - ) - api_key: str = Field(..., description="API key for the Azure API.") - api_version: str = Field(latest_api_version, description="API version for the Azure API") - - @model_validator(mode="before") - def set_default_api_version(cls, values): - """ - This ensures that api_version is always set to the default if None is passed in. 
- """ - if values.get("api_version") is None: - values["api_version"] = cls.model_fields["latest_api_version"].default - return values - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list - - model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version) - configs = [] - for model_option in model_options: - model_name = model_option["id"] - context_window_size = self.get_model_context_window(model_name) - model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version) - configs.append( - LLMConfig( - model=model_name, - model_endpoint_type="azure", - model_endpoint=model_endpoint, - context_window=context_window_size, - handle=self.get_handle(model_name), - provider_name=self.name, - provider_category=self.provider_category, - ), - ) - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list - - model_options = azure_openai_get_embeddings_model_list( - self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True - ) - configs = [] - for model_option in model_options: - model_name = model_option["id"] - model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version) - configs.append( - EmbeddingConfig( - embedding_model=model_name, - embedding_endpoint_type="azure", - embedding_endpoint=model_endpoint, - embedding_dim=768, - embedding_chunk_size=300, # NOTE: max is 2048 - handle=self.get_handle(model_name), - batch_size=1024, - ), - ) - return configs - - def get_model_context_window(self, model_name: str) -> Optional[int]: - """ - This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model. 
- """ - context_window = AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, None) - if context_window is None: - context_window = LLM_MAX_TOKENS.get(model_name, 4096) - return context_window - - -class VLLMChatCompletionsProvider(Provider): - """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy""" - - # NOTE: vLLM only serves one model at a time (so could configure that through env variables) - provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - base_url: str = Field(..., description="Base URL for the vLLM API.") - - def list_llm_models(self) -> List[LLMConfig]: - # not supported with vLLM - from letta.llm_api.openai import openai_get_model_list - - assert self.base_url, "base_url is required for vLLM provider" - response = openai_get_model_list(self.base_url, api_key=None) - - configs = [] - for model in response["data"]: - configs.append( - LLMConfig( - model=model["id"], - model_endpoint_type="openai", - model_endpoint=self.base_url, - context_window=model["max_model_len"], - handle=self.get_handle(model["id"]), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - # not supported with vLLM - return [] - - -class VLLMCompletionsProvider(Provider): - """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper""" - - # NOTE: vLLM only serves one model at a time (so could configure that through env variables) - provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - base_url: str = Field(..., description="Base URL for the vLLM API.") - 
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") - - def list_llm_models(self) -> List[LLMConfig]: - # not supported with vLLM - from letta.llm_api.openai import openai_get_model_list - - response = openai_get_model_list(self.base_url, api_key=None) - - configs = [] - for model in response["data"]: - configs.append( - LLMConfig( - model=model["id"], - model_endpoint_type="vllm", - model_endpoint=self.base_url, - model_wrapper=self.default_prompt_formatter, - context_window=model["max_model_len"], - handle=self.get_handle(model["id"]), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - # not supported with vLLM - return [] - - -class CohereProvider(OpenAIProvider): - pass - - -class BedrockProvider(Provider): - provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") - provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - region: str = Field(..., description="AWS region for Bedrock") - - def check_api_key(self): - """Check if the Bedrock credentials are valid""" - from letta.errors import LLMAuthenticationError - from letta.llm_api.aws_bedrock import bedrock_get_model_list - - try: - # For BYOK providers, use the custom credentials - if self.provider_category == ProviderCategory.byok: - # If we can list models, the credentials are valid - bedrock_get_model_list( - region_name=self.region, - access_key_id=self.access_key, - secret_access_key=self.api_key, # api_key stores the secret access key - ) - else: - # For base providers, use default credentials - bedrock_get_model_list(region_name=self.region) - except Exception as e: - raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}") - - def list_llm_models(self): - 
from letta.llm_api.aws_bedrock import bedrock_get_model_list - - models = bedrock_get_model_list(self.region) - - configs = [] - for model_summary in models: - model_arn = model_summary["inferenceProfileArn"] - configs.append( - LLMConfig( - model=model_arn, - model_endpoint_type=self.provider_type.value, - model_endpoint=None, - context_window=self.get_model_context_window(model_arn), - handle=self.get_handle(model_arn), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - return configs - - async def list_llm_models_async(self) -> List[LLMConfig]: - from letta.llm_api.aws_bedrock import bedrock_get_model_list_async - - models = await bedrock_get_model_list_async( - self.access_key, - self.api_key, - self.region, - ) - - configs = [] - for model_summary in models: - model_arn = model_summary["inferenceProfileArn"] - configs.append( - LLMConfig( - model=model_arn, - model_endpoint_type=self.provider_type.value, - model_endpoint=None, - context_window=self.get_model_context_window(model_arn), - handle=self.get_handle(model_arn), - provider_name=self.name, - provider_category=self.provider_category, - ) - ) - - return configs - - def list_embedding_models(self): - return [] - - def get_model_context_window(self, model_name: str) -> Optional[int]: - # Context windows for Claude models - from letta.llm_api.aws_bedrock import bedrock_get_model_context_window - - return bedrock_get_model_context_window(model_name) - - def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str: - print(model_name) - model = model_name.split(".")[-1] - return f"{self.name}/{model}" diff --git a/letta/schemas/providers/__init__.py b/letta/schemas/providers/__init__.py new file mode 100644 index 00000000..cc3cbd69 --- /dev/null +++ b/letta/schemas/providers/__init__.py @@ -0,0 +1,47 @@ +# Provider base classes and utilities +# Provider implementations +from .anthropic import AnthropicProvider +from .azure import 
AzureProvider +from .base import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate +from .bedrock import BedrockProvider +from .cerebras import CerebrasProvider +from .cohere import CohereProvider +from .deepseek import DeepSeekProvider +from .google_gemini import GoogleAIProvider +from .google_vertex import GoogleVertexProvider +from .groq import GroqProvider +from .letta import LettaProvider +from .lmstudio import LMStudioOpenAIProvider +from .mistral import MistralProvider +from .ollama import OllamaProvider +from .openai import OpenAIProvider +from .together import TogetherProvider +from .vllm import VLLMProvider +from .xai import XAIProvider + +__all__ = [ + # Base classes + "Provider", + "ProviderBase", + "ProviderCreate", + "ProviderUpdate", + "ProviderCheck", + # Provider implementations + "AnthropicProvider", + "AzureProvider", + "BedrockProvider", + "CerebrasProvider", # NEW + "CohereProvider", + "DeepSeekProvider", + "GoogleAIProvider", + "GoogleVertexProvider", + "GroqProvider", + "LettaProvider", + "LMStudioOpenAIProvider", + "MistralProvider", + "OllamaProvider", + "OpenAIProvider", + "TogetherProvider", + "VLLMProvider", # Replaces ChatCompletions and Completions + "XAIProvider", +] diff --git a/letta/schemas/providers/anthropic.py b/letta/schemas/providers/anthropic.py new file mode 100644 index 00000000..eac4c90d --- /dev/null +++ b/letta/schemas/providers/anthropic.py @@ -0,0 +1,78 @@ +import warnings +from typing import Literal + +from pydantic import Field + +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + + +class AnthropicProvider(Provider): + provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = 
Field(..., description="API key for the Anthropic API.") + base_url: str = "https://api.anthropic.com/v1" + + async def check_api_key(self): + from letta.llm_api.anthropic import anthropic_check_valid_api_key + + anthropic_check_valid_api_key(self.api_key) + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.anthropic import anthropic_get_model_list_async + + models = await anthropic_get_model_list_async(api_key=self.api_key) + return self._list_llm_models(models) + + def _list_llm_models(self, models) -> list[LLMConfig]: + from letta.llm_api.anthropic import MODEL_LIST + + configs = [] + for model in models: + if any((model.get("type") != "model", "id" not in model, model.get("id").startswith("claude-2"))): + continue + + # Anthropic doesn't return the context window in their API + if "context_window" not in model: + # Remap list to name: context_window + model_library = {m["name"]: m["context_window"] for m in MODEL_LIST} + # Attempt to look it up in a hardcoded list + if model["id"] in model_library: + model["context_window"] = model_library[model["id"]] + else: + # On fallback, we can set 200k (generally safe), but we should warn the user + warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000") + model["context_window"] = 200000 + + max_tokens = 8192 + if "claude-3-opus" in model["id"]: + max_tokens = 4096 + if "claude-3-haiku" in model["id"]: + max_tokens = 4096 + # TODO: set for 3-7 extended thinking mode + + # NOTE: from 2025-02 + # We set this to false by default, because Anthropic can + # natively support tags inside of content fields + # However, putting COT inside of tool calls can make it more + # reliable for tool calling (no chance of a non-tool call step) + # Since tool_choice_type 'any' doesn't work with in-content COT + # NOTE For Haiku, it can be flaky if we don't enable this by default + # inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False + 
inner_thoughts_in_kwargs = True # we no longer support thinking tags + + configs.append( + LLMConfig( + model=model["id"], + model_endpoint_type="anthropic", + model_endpoint=self.base_url, + context_window=model["context_window"], + handle=self.get_handle(model["id"]), + put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, + max_tokens=max_tokens, + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs diff --git a/letta/schemas/providers/azure.py b/letta/schemas/providers/azure.py new file mode 100644 index 00000000..e51c1775 --- /dev/null +++ b/letta/schemas/providers/azure.py @@ -0,0 +1,80 @@ +from typing import ClassVar, Literal + +from pydantic import Field, field_validator + +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS +from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint +from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + + +class AzureProvider(Provider): + LATEST_API_VERSION: ClassVar[str] = "2024-09-01-preview" + + provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + # Note: 2024-09-01-preview was set here until 2025-07-16. + # set manually, see: https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation + latest_api_version: str = "2025-04-01-preview" + base_url: str = Field( + ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." 
+ ) + api_key: str = Field(..., description="API key for the Azure API.") + api_version: str = Field(default=LATEST_API_VERSION, description="API version for the Azure API") + + @field_validator("api_version", mode="before") + def replace_none_with_default(cls, v): + return v if v is not None else cls.LATEST_API_VERSION + + async def list_llm_models_async(self) -> list[LLMConfig]: + # TODO (cliandy): asyncify + from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list + + model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version) + configs = [] + for model_option in model_options: + model_name = model_option["id"] + context_window_size = self.get_model_context_window(model_name) + model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version) + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="azure", + model_endpoint=model_endpoint, + context_window=context_window_size, + handle=self.get_handle(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + async def list_embedding_models_async(self) -> list[EmbeddingConfig]: + # TODO (cliandy): asyncify dependent function calls + from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list + + model_options = azure_openai_get_embeddings_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version) + configs = [] + for model_option in model_options: + model_name = model_option["id"] + model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version) + configs.append( + EmbeddingConfig( + embedding_model=model_name, + embedding_endpoint_type="azure", + embedding_endpoint=model_endpoint, + embedding_dim=768, # TODO generated 1536? 
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, # old note: max is 2048 + handle=self.get_handle(model_name, is_embedding=True), + batch_size=1024, + ) + ) + return configs + + def get_model_context_window(self, model_name: str) -> int | None: + # Hard coded as there are no API endpoints for this + llm_default = LLM_MAX_TOKENS.get(model_name, 4096) + return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default) diff --git a/letta/schemas/providers/base.py b/letta/schemas/providers/base.py new file mode 100644 index 00000000..eef2cb39 --- /dev/null +++ b/letta/schemas/providers/base.py @@ -0,0 +1,201 @@ +from datetime import datetime + +from pydantic import BaseModel, Field, model_validator + +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.letta_base import LettaBase +from letta.schemas.llm_config import LLMConfig +from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES +from letta.settings import model_settings + + +class ProviderBase(LettaBase): + __id_prefix__ = "provider" + + +class Provider(ProviderBase): + id: str | None = Field(None, description="The id of the provider, lazily created by the database manager.") + name: str = Field(..., description="The name of the provider") + provider_type: ProviderType = Field(..., description="The type of the provider") + provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)") + api_key: str | None = Field(None, description="API key or secret key used for requests to the provider.") + base_url: str | None = Field(None, description="Base URL for the provider.") + access_key: str | None = Field(None, description="Access key used for requests to the provider.") + region: str | None = Field(None, description="Region used for requests to the provider.") + organization_id: str | None 
= Field(None, description="The organization id of the user") + updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.") + + @model_validator(mode="after") + def default_base_url(self): + if self.provider_type == ProviderType.openai and self.base_url is None: + self.base_url = model_settings.openai_api_base + return self + + def resolve_identifier(self): + if not self.id: + self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) + + async def check_api_key(self): + """Check if the API key is valid for the provider""" + raise NotImplementedError + + def list_llm_models(self) -> list[LLMConfig]: + """List available LLM models (deprecated: use list_llm_models_async)""" + import asyncio + import warnings + + warnings.warn("list_llm_models is deprecated, use list_llm_models_async instead", DeprecationWarning, stacklevel=2) + + # Simplified asyncio handling - just use asyncio.run() + # This works in most contexts and avoids complex event loop detection + try: + return asyncio.run(self.list_llm_models_async()) + except RuntimeError as e: + # If we're in an active event loop context, use a thread pool + if "cannot be called from a running event loop" in str(e): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.list_llm_models_async()) + return future.result() + else: + raise + + async def list_llm_models_async(self) -> list[LLMConfig]: + return [] + + def list_embedding_models(self) -> list[EmbeddingConfig]: + """List available embedding models (deprecated: use list_embedding_models_async)""" + import asyncio + import warnings + + warnings.warn("list_embedding_models is deprecated, use list_embedding_models_async instead", DeprecationWarning, stacklevel=2) + + # Simplified asyncio handling - just use asyncio.run() + # This works in most contexts and avoids complex event loop detection + try: + return 
asyncio.run(self.list_embedding_models_async()) + except RuntimeError as e: + # If we're in an active event loop context, use a thread pool + if "cannot be called from a running event loop" in str(e): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.list_embedding_models_async()) + return future.result() + else: + raise + + async def list_embedding_models_async(self) -> list[EmbeddingConfig]: + """List available embedding models. The following do not have support for embedding models: + Anthropic, Bedrock, Cerebras, Deepseek, Groq, Mistral, xAI + """ + return [] + + def get_model_context_window(self, model_name: str) -> int | None: + raise NotImplementedError + + async def get_model_context_window_async(self, model_name: str) -> int | None: + raise NotImplementedError + + def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str: + """ + Get the handle for a model, with support for custom overrides. + + Args: + model_name (str): The name of the model. + is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False. + + Returns: + str: The handle for the model. 
+ """ + base_name = base_name if base_name else self.name + + overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES + if base_name in overrides and model_name in overrides[base_name]: + model_name = overrides[base_name][model_name] + + return f"{base_name}/{model_name}" + + def cast_to_subtype(self): + # Import here to avoid circular imports + from letta.schemas.providers import ( + AnthropicProvider, + AzureProvider, + BedrockProvider, + CerebrasProvider, + CohereProvider, + DeepSeekProvider, + GoogleAIProvider, + GoogleVertexProvider, + GroqProvider, + LettaProvider, + LMStudioOpenAIProvider, + MistralProvider, + OllamaProvider, + OpenAIProvider, + TogetherProvider, + VLLMProvider, + XAIProvider, + ) + + match self.provider_type: + case ProviderType.letta: + return LettaProvider(**self.model_dump(exclude_none=True)) + case ProviderType.openai: + return OpenAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.anthropic: + return AnthropicProvider(**self.model_dump(exclude_none=True)) + case ProviderType.google_ai: + return GoogleAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.google_vertex: + return GoogleVertexProvider(**self.model_dump(exclude_none=True)) + case ProviderType.azure: + return AzureProvider(**self.model_dump(exclude_none=True)) + case ProviderType.groq: + return GroqProvider(**self.model_dump(exclude_none=True)) + case ProviderType.together: + return TogetherProvider(**self.model_dump(exclude_none=True)) + case ProviderType.ollama: + return OllamaProvider(**self.model_dump(exclude_none=True)) + case ProviderType.vllm: + return VLLMProvider(**self.model_dump(exclude_none=True)) # Removed support for CompletionsProvider + case ProviderType.mistral: + return MistralProvider(**self.model_dump(exclude_none=True)) + case ProviderType.deepseek: + return DeepSeekProvider(**self.model_dump(exclude_none=True)) + case ProviderType.cerebras: + return 
CerebrasProvider(**self.model_dump(exclude_none=True)) + case ProviderType.xai: + return XAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.lmstudio_openai: + return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.bedrock: + return BedrockProvider(**self.model_dump(exclude_none=True)) + case ProviderType.cohere: + return CohereProvider(**self.model_dump(exclude_none=True)) + case _: + raise ValueError(f"Unknown provider type: {self.provider_type}") + + +class ProviderCreate(ProviderBase): + name: str = Field(..., description="The name of the provider.") + provider_type: ProviderType = Field(..., description="The type of the provider.") + api_key: str = Field(..., description="API key or secret key used for requests to the provider.") + access_key: str | None = Field(None, description="Access key used for requests to the provider.") + region: str | None = Field(None, description="Region used for requests to the provider.") + + +class ProviderUpdate(ProviderBase): + api_key: str = Field(..., description="API key or secret key used for requests to the provider.") + access_key: str | None = Field(None, description="Access key used for requests to the provider.") + region: str | None = Field(None, description="Region used for requests to the provider.") + + +class ProviderCheck(BaseModel): + provider_type: ProviderType = Field(..., description="The type of the provider.") + api_key: str = Field(..., description="API key or secret key used for requests to the provider.") + access_key: str | None = Field(None, description="Access key used for requests to the provider.") + region: str | None = Field(None, description="Region used for requests to the provider.") diff --git a/letta/schemas/providers/bedrock.py b/letta/schemas/providers/bedrock.py new file mode 100644 index 00000000..d7d8437f --- /dev/null +++ b/letta/schemas/providers/bedrock.py @@ -0,0 +1,78 @@ +""" +Note that this formally only supports Anthropic Bedrock. 
+TODO (cliandy): determine what other providers are supported and what is needed to add support. +""" + +from typing import Literal + +from pydantic import Field + +from letta.log import get_logger +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + +logger = get_logger(__name__) + + +class BedrockProvider(Provider): + provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + region: str = Field(..., description="AWS region for Bedrock") + + async def check_api_key(self): + """Check if the Bedrock credentials are valid""" + from letta.errors import LLMAuthenticationError + from letta.llm_api.aws_bedrock import bedrock_get_model_list_async + + try: + # For BYOK providers, use the custom credentials + if self.provider_category == ProviderCategory.byok: + # If we can list models, the credentials are valid + await bedrock_get_model_list_async( + access_key_id=self.access_key, + secret_access_key=self.api_key, # api_key stores the secret access key + region_name=self.region, + ) + else: + # For base providers, use default credentials + bedrock_get_model_list(region_name=self.region) + except Exception as e: + raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}") + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.aws_bedrock import bedrock_get_model_list_async + + models = await bedrock_get_model_list_async( + self.access_key, + self.api_key, + self.region, + ) + + configs = [] + for model_summary in models: + model_arn = model_summary["inferenceProfileArn"] + configs.append( + LLMConfig( + model=model_arn, + model_endpoint_type=self.provider_type.value, + model_endpoint=None, + 
context_window=self.get_model_context_window(model_arn), + handle=self.get_handle(model_arn), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + def get_model_context_window(self, model_name: str) -> int | None: + # Context windows for Claude models + from letta.llm_api.aws_bedrock import bedrock_get_model_context_window + + return bedrock_get_model_context_window(model_name) + + def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str: + logger.debug("Getting handle for model_name: %s", model_name) + model = model_name.split(".")[-1] + return f"{self.name}/{model}" diff --git a/letta/schemas/providers/cerebras.py b/letta/schemas/providers/cerebras.py new file mode 100644 index 00000000..173dc4ba --- /dev/null +++ b/letta/schemas/providers/cerebras.py @@ -0,0 +1,79 @@ +import warnings +from typing import Literal + +from pydantic import Field + +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + + +class CerebrasProvider(OpenAIProvider): + """ + Cerebras Inference API is OpenAI-compatible and focuses on ultra-fast inference. 
+ + Available Models (as of 2025): + - llama-4-scout-17b-16e-instruct: Llama 4 Scout (109B params, 10M context, ~2600 tokens/s) + - llama3.1-8b: Llama 3.1 8B (8B params, 128K context, ~2200 tokens/s) + - llama-3.3-70b: Llama 3.3 70B (70B params, 128K context, ~2100 tokens/s) + - qwen-3-32b: Qwen 3 32B (32B params, 131K context, ~2100 tokens/s) + - deepseek-r1-distill-llama-70b: DeepSeek R1 Distill (70B params, 128K context, ~1700 tokens/s) + """ + + provider_type: Literal[ProviderType.cerebras] = Field(ProviderType.cerebras, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field("https://api.cerebras.ai/v1", description="Base URL for the Cerebras API.") + api_key: str = Field(..., description="API key for the Cerebras API.") + + def get_model_context_window_size(self, model_name: str) -> int | None: + """Cerebras has limited context window sizes. + + see https://inference-docs.cerebras.ai/support/pricing for details by plan + """ + is_free_tier = True + if is_free_tier: + return 8192 + return 128000 + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + response = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + + if "data" in response: + data = response["data"] + else: + data = response + + configs = [] + for model in data: + assert "id" in model, f"Cerebras model missing 'id' field: {model}" + model_name = model["id"] + + # Check if model has context_length in response + if "context_length" in model: + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + if not context_window_size: + warnings.warn(f"Couldn't find context window size for model {model_name}") + continue + + # Cerebras supports function calling + put_inner_thoughts_in_kwargs = True + + 
configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="openai", # Cerebras uses OpenAI-compatible endpoint + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs diff --git a/letta/schemas/providers/cohere.py b/letta/schemas/providers/cohere.py new file mode 100644 index 00000000..ce3d6150 --- /dev/null +++ b/letta/schemas/providers/cohere.py @@ -0,0 +1,18 @@ +from typing import Literal + +from pydantic import Field + +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + + +# TODO (cliandy): this needs to be implemented +class CohereProvider(OpenAIProvider): + provider_type: Literal[ProviderType.cohere] = Field(ProviderType.cohere, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = "" + api_key: str = Field(..., description="API key for the Cohere API.") + + async def list_llm_models_async(self) -> list[LLMConfig]: + raise NotImplementedError diff --git a/letta/schemas/providers/deepseek.py b/letta/schemas/providers/deepseek.py new file mode 100644 index 00000000..0c1ae0c2 --- /dev/null +++ b/letta/schemas/providers/deepseek.py @@ -0,0 +1,63 @@ +from typing import Literal + +from pydantic import Field + +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + + +class DeepSeekProvider(OpenAIProvider): + """ + DeepSeek ChatCompletions API is similar to OpenAI's reasoning API, + but with slight differences: + * For example, DeepSeek's API requires perfect interleaving of 
user/assistant + * It also does not support native function calling + """ + + provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") + api_key: str = Field(..., description="API key for the DeepSeek API.") + + # TODO (cliandy): this may need to be updated to reflect current models + def get_model_context_window_size(self, model_name: str) -> int | None: + # DeepSeek doesn't return context window in the model listing, + # so these are hardcoded from their website + if model_name == "deepseek-reasoner": + return 64000 + elif model_name == "deepseek-chat": + return 64000 + else: + return None + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + response = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + data = response.get("data", response) + + configs = [] + for model in data: + check = self._do_model_checks_for_name_and_context_size(model) + if check is None: + continue + model_name, context_window_size = check + + # Not used for deepseek-reasoner, but otherwise is true + put_inner_thoughts_in_kwargs = False if model_name == "deepseek-reasoner" else True + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="deepseek", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs diff --git a/letta/schemas/providers/google_gemini.py b/letta/schemas/providers/google_gemini.py new file mode 100644 index 00000000..6404e0fc --- /dev/null +++ 
b/letta/schemas/providers/google_gemini.py @@ -0,0 +1,102 @@ +import asyncio +from typing import Literal + +from pydantic import Field + +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + + +class GoogleAIProvider(Provider): + provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the Google AI API.") + base_url: str = "https://generativelanguage.googleapis.com" + + async def check_api_key(self): + from letta.llm_api.google_ai_client import google_ai_check_valid_api_key + + google_ai_check_valid_api_key(self.api_key) + + async def list_llm_models_async(self): + from letta.llm_api.google_ai_client import google_ai_get_model_list_async + + # Get and filter the model list + model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) + model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] + model_options = [str(m["name"]) for m in model_options] + + # filter by model names + model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] + + # Add support for all gemini models + model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] + + # Prepare tasks for context window lookups in parallel + async def create_config(model): + context_window = await self.get_model_context_window_async(model) + return LLMConfig( + model=model, + model_endpoint_type="google_ai", + model_endpoint=self.base_url, + context_window=context_window, + 
handle=self.get_handle(model), + max_tokens=8192, + provider_name=self.name, + provider_category=self.provider_category, + ) + + # Execute all config creation tasks concurrently + configs = await asyncio.gather(*[create_config(model) for model in model_options]) + + return configs + + async def list_embedding_models_async(self): + from letta.llm_api.google_ai_client import google_ai_get_model_list_async + + # TODO: use base_url instead + model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) + return self._list_embedding_models(model_options) + + def _list_embedding_models(self, model_options): + # filter by 'generateContent' models + model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]] + model_options = [str(m["name"]) for m in model_options] + model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] + + configs = [] + for model in model_options: + configs.append( + EmbeddingConfig( + embedding_model=model, + embedding_endpoint_type="google_ai", + embedding_endpoint=self.base_url, + embedding_dim=768, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, # NOTE: max is 2048 + handle=self.get_handle(model, is_embedding=True), + batch_size=1024, + ) + ) + return configs + + def get_model_context_window(self, model_name: str) -> int | None: + import warnings + + warnings.warn("This is deprecated, use get_model_context_window_async when possible.", DeprecationWarning) + from letta.llm_api.google_ai_client import google_ai_get_model_context_window + + if model_name in LLM_MAX_TOKENS: + return LLM_MAX_TOKENS[model_name] + else: + return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) + + async def get_model_context_window_async(self, model_name: str) -> int | None: + from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async + + if model_name in LLM_MAX_TOKENS: + return 
LLM_MAX_TOKENS[model_name] + else: + return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name) diff --git a/letta/schemas/providers/google_vertex.py b/letta/schemas/providers/google_vertex.py new file mode 100644 index 00000000..0ed68541 --- /dev/null +++ b/letta/schemas/providers/google_vertex.py @@ -0,0 +1,54 @@ +from typing import Literal + +from pydantic import Field + +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + + +# TODO (cliandy): GoogleVertexProvider uses hardcoded models vs Gemini fetches from API +class GoogleVertexProvider(Provider): + provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.") + google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.") + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH + + configs = [] + for model, context_length in GOOGLE_MODEL_TO_CONTEXT_LENGTH.items(): + configs.append( + LLMConfig( + model=model, + model_endpoint_type="google_vertex", + model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}", + context_window=context_length, + handle=self.get_handle(model), + max_tokens=8192, + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + async def list_embedding_models_async(self) -> 
list[EmbeddingConfig]: + from letta.llm_api.google_constants import GOOGLE_EMBEDING_MODEL_TO_DIM + + configs = [] + for model, dim in GOOGLE_EMBEDING_MODEL_TO_DIM.items(): + configs.append( + EmbeddingConfig( + embedding_model=model, + embedding_endpoint_type="google_vertex", + embedding_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}", + embedding_dim=dim, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, # NOTE: max is 2048 + handle=self.get_handle(model, is_embedding=True), + batch_size=1024, + ) + ) + return configs diff --git a/letta/schemas/providers/groq.py b/letta/schemas/providers/groq.py new file mode 100644 index 00000000..18b4cb31 --- /dev/null +++ b/letta/schemas/providers/groq.py @@ -0,0 +1,35 @@ +from typing import Literal + +from pydantic import Field + +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + + +class GroqProvider(OpenAIProvider): + provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = "https://api.groq.com/openai/v1" + api_key: str = Field(..., description="API key for the Groq API.") + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + response = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + configs = [] + for model in response["data"]: + if "context_window" not in model: + continue + configs.append( + LLMConfig( + model=model["id"], + model_endpoint_type="groq", + model_endpoint=self.base_url, + context_window=model["context_window"], + handle=self.get_handle(model["id"]), + provider_name=self.name, + 
provider_category=self.provider_category, + ) + ) + return configs diff --git a/letta/schemas/providers/letta.py b/letta/schemas/providers/letta.py new file mode 100644 index 00000000..37763884 --- /dev/null +++ b/letta/schemas/providers/letta.py @@ -0,0 +1,39 @@ +from typing import Literal + +from pydantic import Field + +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + + +class LettaProvider(Provider): + provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + + async def list_llm_models_async(self) -> list[LLMConfig]: + return [ + LLMConfig( + model="letta-free", # NOTE: renamed + model_endpoint_type="openai", + model_endpoint=LETTA_MODEL_ENDPOINT, + context_window=30000, + handle=self.get_handle("letta-free"), + provider_name=self.name, + provider_category=self.provider_category, + ) + ] + + async def list_embedding_models_async(self): + return [ + EmbeddingConfig( + embedding_model="letta-free", # NOTE: renamed + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_dim=1024, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle("letta-free", is_embedding=True), + ) + ] diff --git a/letta/schemas/providers/lmstudio.py b/letta/schemas/providers/lmstudio.py new file mode 100644 index 00000000..6e4a639a --- /dev/null +++ b/letta/schemas/providers/lmstudio.py @@ -0,0 +1,97 @@ +import warnings +from typing import Literal + +from pydantic import Field + +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE +from letta.schemas.embedding_config import 
EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + + +class LMStudioOpenAIProvider(OpenAIProvider): + provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.") + api_key: str | None = Field(None, description="API key for the LMStudio API.") + + @property + def model_endpoint_url(self): + # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'; drop only a trailing '/v1' (str.strip would eat any trailing '/', 'v' or '1' characters, e.g. the last octet of 127.0.0.1) + return f"{self.base_url.rstrip('/').removesuffix('/v1')}/api/v0" + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + response = await openai_get_model_list_async(self.model_endpoint_url) + + if "data" not in response: + warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}") + return [] + + configs = [] + for model in response["data"]: + model_type = model.get("type") + if not model_type: + warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}") + continue + if model_type not in ("vlm", "llm"): + continue + + # TODO (cliandy): previously we didn't get the backup context size, is this valid? 
+ check = self._do_model_checks_for_name_and_context_size(model) + if check is None: + continue + model_name, context_window_size = check + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="openai", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + async def list_embedding_models_async(self) -> list[EmbeddingConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + response = await openai_get_model_list_async(self.model_endpoint_url) + + if "data" not in response: + warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}") + return [] + + configs = [] + for model in response["data"]: + model_type = model.get("type") + if not model_type: + warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}") + continue + if model_type != "embeddings": + continue + + # TODO (cliandy): previously we didn't get the backup context size, is this valid? 
+ check = self._do_model_checks_for_name_and_context_size(model, length_key="max_context_length") + if check is None: + continue + model_name, context_window_size = check + + configs.append( + EmbeddingConfig( + embedding_model=model_name, + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=768, # Default embedding dimension, not context window + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, # NOTE: max is 2048 + handle=self.get_handle(model_name), + ), + ) + + return configs diff --git a/letta/schemas/providers/mistral.py b/letta/schemas/providers/mistral.py new file mode 100644 index 00000000..2eeb3a23 --- /dev/null +++ b/letta/schemas/providers/mistral.py @@ -0,0 +1,41 @@ +from typing import Literal + +from pydantic import Field + +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + + +class MistralProvider(Provider): + provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the Mistral API.") + base_url: str = "https://api.mistral.ai/v1" + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.mistral import mistral_get_model_list_async + + # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... 
+ # See: https://openrouter.ai/docs/requests + response = await mistral_get_model_list_async(self.base_url, api_key=self.api_key) + + assert "data" in response, f"Mistral model query response missing 'data' field: {response}" + + configs = [] + for model in response["data"]: + # If model has chat completions and function calling enabled + if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]: + configs.append( + LLMConfig( + model=model["id"], + model_endpoint_type="openai", + model_endpoint=self.base_url, + context_window=model["max_context_length"], + handle=self.get_handle(model["id"]), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs diff --git a/letta/schemas/providers/ollama.py b/letta/schemas/providers/ollama.py new file mode 100644 index 00000000..b9ddaa2c --- /dev/null +++ b/letta/schemas/providers/ollama.py @@ -0,0 +1,151 @@ +from typing import Literal + +import aiohttp +import requests +from pydantic import Field + +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE +from letta.log import get_logger +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + +logger = get_logger(__name__) + + +class OllamaProvider(OpenAIProvider): + """Ollama provider that uses the native /api/generate endpoint + + See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion + """ + + provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(..., description="Base URL for the Ollama API.") + api_key: str | None = Field(None, description="API key for the Ollama API (default: `None`).") + 
default_prompt_formatter: str = Field( + ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API." + ) + + async def list_llm_models_async(self) -> list[LLMConfig]: + """List available LLM Models from Ollama + + https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models""" + endpoint = f"{self.base_url}/api/tags" + async with aiohttp.ClientSession() as session: + async with session.get(endpoint) as response: + if response.status != 200: + raise Exception(f"Failed to list Ollama models: {await response.text()}") + response_json = await response.json() + + configs = [] + for model in response_json["models"]: + context_window = self.get_model_context_window(model["name"]) + if context_window is None: + logger.warning("Ollama model %s has no context window", model["name"]) + continue + configs.append( + LLMConfig( + model=model["name"], + model_endpoint_type="ollama", + model_endpoint=self.base_url, + model_wrapper=self.default_prompt_formatter, + context_window=context_window, + handle=self.get_handle(model["name"]), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + async def list_embedding_models_async(self) -> list[EmbeddingConfig]: + """List available embedding models from Ollama + + https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models + """ + endpoint = f"{self.base_url}/api/tags" + async with aiohttp.ClientSession() as session: + async with session.get(endpoint) as response: + if response.status != 200: + raise Exception(f"Failed to list Ollama models: {await response.text()}") + response_json = await response.json() + + configs = [] + for model in response_json["models"]: + embedding_dim = await self._get_model_embedding_dim_async(model["name"]) + if not embedding_dim: + logger.warning("Ollama model %s has no embedding dimension", model["name"]) + continue + configs.append( + EmbeddingConfig( + embedding_model=model["name"], + embedding_endpoint_type="ollama", + 
embedding_endpoint=self.base_url, + embedding_dim=embedding_dim, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle(model["name"], is_embedding=True), + ) + ) + return configs + + def get_model_context_window(self, model_name: str) -> int | None: + """Gets model context window for Ollama. As this can look different based on models, + we use the following for guidance: + + "llama.context_length": 8192, + "llama.embedding_length": 4096, + source: https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information + + FROM 2024-10-08 + Notes from vLLM around keys + source: https://github.com/vllm-project/vllm/blob/72ad2735823e23b4e1cc79b7c73c3a5f3c093ab0/vllm/config.py#L3488 + + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Whisper + "max_target_positions", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + max_position_embeddings + parse model cards: nous, dolphon, llama + """ + endpoint = f"{self.base_url}/api/show" + payload = {"name": model_name, "verbose": True} + response = requests.post(endpoint, json=payload) + if response.status_code != 200: + return None + + try: + model_info = response.json() + # Try to extract context window from model parameters + if "model_info" in model_info and "llama.context_length" in model_info["model_info"]: + return int(model_info["model_info"]["llama.context_length"]) + except Exception: + pass + logger.warning(f"Failed to get model context window for {model_name}") + return None + + async def _get_model_embedding_dim_async(self, model_name: str): + async with aiohttp.ClientSession() as session: + async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response: + response_json = await response.json() + + if "model_info" not in response_json: + if "error" in response_json: + logger.warning("Ollama 
fetch model info error for %s: %s", model_name, response_json["error"]) + return None + + return response_json["model_info"].get("embedding_length") diff --git a/letta/schemas/providers/openai.py b/letta/schemas/providers/openai.py new file mode 100644 index 00000000..2a3bc0b8 --- /dev/null +++ b/letta/schemas/providers/openai.py @@ -0,0 +1,241 @@ +from typing import Literal + +from pydantic import Field + +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS +from letta.log import get_logger +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + +logger = get_logger(__name__) + +ALLOWED_PREFIXES = {"gpt-4", "o1", "o3", "o4"} +DISALLOWED_KEYWORDS = {"transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"} +DEFAULT_EMBEDDING_BATCH_SIZE = 1024 + + +class OpenAIProvider(Provider): + provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the OpenAI API.") + base_url: str = Field(..., description="Base URL for the OpenAI API.") + + async def check_api_key(self): + from letta.llm_api.openai import openai_check_valid_api_key + + openai_check_valid_api_key(self.base_url, self.api_key) + + async def _get_models_async(self) -> list[dict]: + from letta.llm_api.openai import openai_get_model_list_async + + # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... 
+ # See: https://openrouter.ai/docs/requests + extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None + + # Similar to Nebius (otherwise keep the OpenRouter params above instead of resetting them to None) + extra_params = {"verbose": True} if "nebius.com" in self.base_url else extra_params + + response = await openai_get_model_list_async( + self.base_url, + api_key=self.api_key, + extra_params=extra_params, + # fix_url=True, # NOTE: make sure together ends with /v1 + ) + + # TODO (cliandy): this is brittle as TogetherAI seems to result in a list instead of having a 'data' field + data = response.get("data", response) if isinstance(response, dict) else response + assert isinstance(data, list) + return data + + async def list_llm_models_async(self) -> list[LLMConfig]: + data = await self._get_models_async() + return self._list_llm_models(data) + + def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]: + """ + This handles filtering out LLM Models by provider that meet Letta's requirements. + """ + configs = [] + for model in data: + check = self._do_model_checks_for_name_and_context_size(model) + if check is None: + continue + model_name, context_window_size = check + + # ===== Provider filtering ===== + # TogetherAI: includes the type, which we can use to filter out embedding models + if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url: + if "type" in model and model["type"] not in ["chat", "language"]: + continue + + # for TogetherAI, we need to skip the models that don't support JSON mode / function calling + # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: { + # "error": { + # "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling", + # "type": "invalid_request_error", + # "param": null, + # "code": "constraints_model" + # } + # } + if "config" not in model: + continue + + # Nebius: includes the type, which we can use to filter for text models + if "nebius.com" in 
self.base_url: + model_type = model.get("architecture", {}).get("modality") + if model_type not in ["text->text", "text+image->text"]: + continue + + # OpenAI + # NOTE: o1-mini and o1-preview do not support tool calling + # NOTE: o1-mini does not support system messages + # NOTE: o1-pro is only available in Responses API + if self.base_url == "https://api.openai.com/v1": + if any(keyword in model_name for keyword in DISALLOWED_KEYWORDS) or not any( + model_name.startswith(prefix) for prefix in ALLOWED_PREFIXES + ): + continue + + # We'll set the model endpoint based on the base URL + # Note: openai-proxy just means that the model is using the OpenAIProvider + if self.base_url != "https://api.openai.com/v1": + handle = self.get_handle(model_name, base_name="openai-proxy") + else: + handle = self.get_handle(model_name) + + config = LLMConfig( + model=model_name, + model_endpoint_type="openai", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=handle, + provider_name=self.name, + provider_category=self.provider_category, + ) + + config = self._set_model_parameter_tuned_defaults(model_name, config) + configs.append(config) + + # for OpenAI, sort in reverse order + if self.base_url == "https://api.openai.com/v1": + configs.sort(key=lambda x: x.model, reverse=True) + return configs + + def _do_model_checks_for_name_and_context_size(self, model: dict, length_key: str = "context_length") -> tuple[str, int] | None: + if "id" not in model: + logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model) + return None + + model_name = model["id"] + context_window_size = model.get(length_key) or self.get_model_context_window_size(model_name) + + if not context_window_size: + logger.info("No context window size found for model: %s", model_name) + return None + + return model_name, context_window_size + + @staticmethod + def _set_model_parameter_tuned_defaults(model_name: str, llm_config: LLMConfig): + """This 
function is used to tune LLMConfig parameters to improve model performance.""" + + # gpt-4o-mini has started to regress with pretty bad emoji spam loops (2025-07) + if "gpt-4o-mini" in model_name or "gpt-4.1-mini" in model_name: + llm_config.frequency_penalty = 1.0 + return llm_config + + async def list_embedding_models_async(self) -> list[EmbeddingConfig]: + if self.base_url == "https://api.openai.com/v1": + # TODO: actually automatically list models for OpenAI + return [ + EmbeddingConfig( + embedding_model="text-embedding-ada-002", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=1536, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle("text-embedding-ada-002", is_embedding=True), + batch_size=DEFAULT_EMBEDDING_BATCH_SIZE, + ), + EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle("text-embedding-3-small", is_embedding=True), + batch_size=DEFAULT_EMBEDDING_BATCH_SIZE, + ), + EmbeddingConfig( + embedding_model="text-embedding-3-large", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle("text-embedding-3-large", is_embedding=True), + batch_size=DEFAULT_EMBEDDING_BATCH_SIZE, + ), + ] + else: + # TODO: this has filtering that doesn't apply for embedding models, fix this. 
+ data = await self._get_models_async() + return self._list_embedding_models(data) + + def _list_embedding_models(self, data) -> list[EmbeddingConfig]: + configs = [] + for model in data: + check = self._do_model_checks_for_name_and_context_size(model) + if check is None: + continue + model_name, context_window_size = check + + # ===== Provider filtering ===== + # TogetherAI: includes the type, which we can use to filter for embedding models + if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url: + if "type" in model and model["type"] not in ["embedding"]: + continue + # Nebius: includes the modality, which we can use to filter for embedding models + elif "nebius.com" in self.base_url: + model_type = model.get("architecture", {}).get("modality") + if model_type not in ["text->embedding"]: + continue + else: + logger.info( + "Skipping embedding models for %s by default, as we don't assume embeddings are supported. " + "Please open an issue on GitHub if support is required.", + self.base_url, + ) + continue + + configs.append( + EmbeddingConfig( + embedding_model=model_name, + embedding_endpoint_type=self.provider_type, + embedding_endpoint=self.base_url, + embedding_dim=context_window_size, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle(model_name, is_embedding=True), + ) + ) + + return configs + + def get_model_context_window_size(self, model_name: str) -> int | None: + if model_name in LLM_MAX_TOKENS: + return LLM_MAX_TOKENS[model_name] + else: + logger.debug( + f"Model %s on %s for provider %s not found in LLM_MAX_TOKENS. 
Using default of %s", + model_name, + self.base_url, + self.__class__.__name__, LLM_MAX_TOKENS["DEFAULT"], + ) + return LLM_MAX_TOKENS["DEFAULT"] + + def get_model_context_window(self, model_name: str) -> int | None: + return self.get_model_context_window_size(model_name) + + async def get_model_context_window_async(self, model_name: str) -> int | None: + return self.get_model_context_window_size(model_name) diff --git a/letta/schemas/providers/together.py b/letta/schemas/providers/together.py new file mode 100644 index 00000000..8a7a000e --- /dev/null +++ b/letta/schemas/providers/together.py @@ -0,0 +1,85 @@ +""" +Note: this supports completions (deprecated by openai) and chat completions via the OpenAI API. +""" + +from typing import Literal + +from pydantic import Field + +from letta.constants import MIN_CONTEXT_WINDOW +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + + +class TogetherProvider(OpenAIProvider): + provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = "https://api.together.xyz/v1" + api_key: str = Field(..., description="API key for the Together API.") + default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + models = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + return self._list_llm_models(models) + + async def list_embedding_models_async(self) -> list[EmbeddingConfig]: + import warnings + + warnings.warn( + "Letta does not 
currently support listing embedding models for Together. Please " + "contact support or reach out via GitHub or Discord to get support." + ) + return [] + + # TODO (cliandy): verify this with openai + def _list_llm_models(self, models) -> list[LLMConfig]: + pass + + # TogetherAI's response is missing the 'data' field + # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" + if "data" in models: + data = models["data"] + else: + data = models + + configs = [] + for model in data: + assert "id" in model, f"TogetherAI model missing 'id' field: {model}" + model_name = model["id"] + + if "context_length" in model: + # Context length is returned in OpenRouter as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + # We need the context length for embeddings too + if not context_window_size: + continue + + # Skip models that are too small for Letta + if context_window_size <= MIN_CONTEXT_WINDOW: + continue + + # TogetherAI includes the type, which we can use to filter for embedding models + if "type" in model and model["type"] not in ["chat", "language"]: + continue + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="together", + model_endpoint=self.base_url, + model_wrapper=self.default_prompt_formatter, + context_window=context_window_size, + handle=self.get_handle(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs diff --git a/letta/schemas/providers/vllm.py b/letta/schemas/providers/vllm.py new file mode 100644 index 00000000..2f261c3e --- /dev/null +++ b/letta/schemas/providers/vllm.py @@ -0,0 +1,57 @@ +""" +Note: this consolidates the vLLM provider for completions (deprecated by openai) +and chat completions. 
Support is provided primarily for the chat completions endpoint, +but to utilize the completions endpoint, set the proper `base_url` and +`default_prompt_formatter`. +""" + +from typing import Literal + +from pydantic import Field + +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider + + +class VLLMProvider(Provider): + provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(..., description="Base URL for the vLLM API.") + api_key: str | None = Field(None, description="API key for the vLLM API.") + default_prompt_formatter: str | None = Field( + default=None, description="Default prompt formatter (aka model wrapper) to use on a /completions style API." + ) + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + # TODO (cliandy): previously unsupported with vLLM; confirm if this is still the case or not + response = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + + data = response.get("data", response) + + configs = [] + for model in data: + model_name = model["id"] + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="openai", # TODO (cliandy): this was previous vllm for the completions provider, why? 
+ model_endpoint=self.base_url, + model_wrapper=self.default_prompt_formatter, + context_window=model["max_model_len"], + handle=self.get_handle(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + async def list_embedding_models_async(self) -> list[EmbeddingConfig]: + # Note: vLLM technically can support embedding models though may require multiple instances + # for now, we will not support embedding models for vLLM. + return [] diff --git a/letta/schemas/providers/xai.py b/letta/schemas/providers/xai.py new file mode 100644 index 00000000..d042aad0 --- /dev/null +++ b/letta/schemas/providers/xai.py @@ -0,0 +1,66 @@ +import warnings +from typing import Literal + +from pydantic import Field + +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + +MODEL_CONTEXT_WINDOWS = { + "grok-3-fast": 131_072, + "grok-3": 131_072, + "grok-3-mini": 131_072, + "grok-3-mini-fast": 131_072, + "grok-4-0709": 256_000, +} + + +class XAIProvider(OpenAIProvider): + """https://docs.x.ai/docs/api-reference""" + + provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the xAI/Grok API.") + base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") + + def get_model_context_window_size(self, model_name: str) -> int | None: + # xAI doesn't return context window in the model listing, + # this is hardcoded from https://docs.x.ai/docs/models + return MODEL_CONTEXT_WINDOWS.get(model_name) + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + response = await 
openai_get_model_list_async(self.base_url, api_key=self.api_key) + + data = response.get("data", response) + + configs = [] + for model in data: + assert "id" in model, f"xAI/Grok model missing 'id' field: {model}" + model_name = model["id"] + + # In case xAI starts supporting it in the future: + if "context_length" in model: + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + if not context_window_size: + warnings.warn(f"Couldn't find context window size for model {model_name}") + continue + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="xai", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 3c1cbdfb..208613ec 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -407,9 +407,10 @@ def start_server( address=host or "127.0.0.1", # Note granian address must be an ip address port=port or REST_DEFAULT_PORT, workers=settings.uvicorn_workers, - # threads= + # runtime_blocking_threads= + # runtime_threads= reload=reload or settings.uvicorn_reload, - reload_ignore_patterns=["openapi_letta.json"], + reload_paths=["letta/"], reload_ignore_worker_failure=True, reload_tick=4000, # set to 4s to prevent crashing on weird state # log_level="info" @@ -451,7 +452,7 @@ def start_server( # runtime_blocking_threads= # runtime_threads= reload=reload or settings.uvicorn_reload, - reload_paths=["../letta/"], + reload_paths=["letta/"], reload_ignore_worker_failure=True, reload_tick=4000, # set to 4s to prevent crashing on weird state # log_level="info" diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index 66cb36ec..6f3f30d2 100644 --- 
a/letta/server/rest_api/routers/v1/providers.py +++ b/letta/server/rest_api/routers/v1/providers.py @@ -71,12 +71,12 @@ async def modify_provider( @router.get("/check", response_model=None, operation_id="check_provider") -def check_provider( +async def check_provider( request: ProviderCheck = Body(...), server: "SyncServer" = Depends(get_letta_server), ): try: - server.provider_manager.check_provider_api_key(provider_check=request) + await server.provider_manager.check_provider_api_key(provider_check=request) return JSONResponse( status_code=status.HTTP_200_OK, content={"message": f"Valid api key for provider_type={request.provider_type.value}"} ) diff --git a/letta/server/server.py b/letta/server/server.py index fbc6db1e..cda9f8e0 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -68,8 +68,6 @@ from letta.schemas.providers import ( OpenAIProvider, Provider, TogetherProvider, - VLLMChatCompletionsProvider, - VLLMCompletionsProvider, XAIProvider, ) from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate, SandboxType @@ -360,7 +358,7 @@ class SyncServer(Server): ) ) # NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup - # see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html + # see: https://docs.vllm.ai/en/stable/features/tool_calling.html # e.g. "... 
--enable-auto-tool-choice --tool-call-parser hermes" self._enabled_providers.append( VLLMChatCompletionsProvider( @@ -460,7 +458,7 @@ class SyncServer(Server): # Determine whether or not to token stream based on the capability of the interface token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False - logger.debug(f"Starting agent step") + logger.debug("Starting agent step") if interface: metadata = interface.metadata if hasattr(interface, "metadata") else None else: @@ -534,7 +532,7 @@ class SyncServer(Server): letta_agent.interface.print_messages_raw(letta_agent.messages) elif command.lower() == "memory": - ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}" + ret_str = "\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}" return ret_str elif command.lower() == "pop" or command.lower().startswith("pop "): @@ -554,7 +552,7 @@ class SyncServer(Server): elif command.lower() == "retry": # TODO this needs to also modify the persistence manager - logger.debug(f"Retrying for another answer") + logger.debug("Retrying for another answer") while len(letta_agent.messages) > 0: if letta_agent.messages[-1].get("role") == "user": # we want to pop up to the last user message and send it again @@ -770,6 +768,7 @@ class SyncServer(Server): self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs) return self._embedding_config_cache[key] + # @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs)) @trace_method async def get_cached_embedding_config_async(self, actor: User, **kwargs): key = make_key(**kwargs) @@ -782,9 +781,9 @@ class SyncServer(Server): self, request: CreateAgent, actor: User, - # interface - interface: Union[AgentInterface, None] = None, + interface: AgentInterface | None = None, ) -> AgentState: + 
warnings.warn("This method is deprecated, use create_agent_async where possible.", DeprecationWarning, stacklevel=2) if request.llm_config is None: if request.model is None: raise ValueError("Must specify either model or llm_config in request") @@ -1320,7 +1319,6 @@ class SyncServer(Server): # TODO: delete data from agent passage stores (?) async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job: - # update job job = await self.job_manager.get_job_by_id_async(job_id, actor=actor) job.status = JobStatus.running @@ -1564,7 +1562,6 @@ class SyncServer(Server): # Add extra metadata to the sources sources_with_metadata = [] for source in sources: - # count number of passages num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id) @@ -2118,7 +2115,6 @@ class SyncServer(Server): mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME) if os.path.exists(mcp_config_path): with open(mcp_config_path, "r") as f: - try: mcp_config = json.load(f) except Exception as e: @@ -2130,7 +2126,6 @@ class SyncServer(Server): # with the value being the schema from StdioServerParameters if MCP_CONFIG_TOPLEVEL_KEY in mcp_config: for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items(): - # No support for duplicate server names if server_name in mcp_server_list: logger.error(f"Duplicate MCP server name found (skipping): {server_name}") @@ -2301,7 +2296,6 @@ class SyncServer(Server): # For streaming response try: - # TODO: move this logic into server.py # Get the generator object off of the agent's streaming interface @@ -2441,9 +2435,9 @@ class SyncServer(Server): if not stream_steps and stream_tokens: raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'") - group = self.group_manager.retrieve_group(group_id=group_id, actor=actor) + group = await self.group_manager.retrieve_group_async(group_id=group_id, actor=actor) agent_state_id = group.manager_agent_id 
or (group.agent_ids[0] if len(group.agent_ids) > 0 else None) - agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_state_id, actor=actor) if agent_state_id else None + agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_state_id, actor=actor) if agent_state_id else None letta_multi_agent = load_multi_agent(group=group, agent_state=agent_state, actor=actor) llm_config = letta_multi_agent.agent_state.llm_config diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py index f972eb6a..60de7c71 100644 --- a/letta/services/file_processor/file_processor.py +++ b/letta/services/file_processor/file_processor.py @@ -58,18 +58,30 @@ class FileProcessor: for page in ocr_response.pages: chunks = text_chunker.chunk_text(page) if not chunks: - log_event("file_processor.chunking_failed", {"filename": filename, "page_index": ocr_response.pages.index(page)}) + log_event( + "file_processor.chunking_failed", + { + "filename": filename, + "page_index": ocr_response.pages.index(page), + }, + ) raise ValueError("No chunks created from text") all_chunks.extend(chunks) all_passages = await self.embedder.generate_embedded_passages( - file_id=file_metadata.id, source_id=source_id, chunks=all_chunks, actor=self.actor + file_id=file_metadata.id, + source_id=source_id, + chunks=all_chunks, + actor=self.actor, ) return all_passages except Exception as e: logger.warning(f"Failed to chunk/embed with file-specific chunker for {filename}: {str(e)}. 
Retrying with default chunker.") - log_event("file_processor.embedding_failed_retrying", {"filename": filename, "error": str(e), "error_type": type(e).__name__}) + log_event( + "file_processor.embedding_failed_retrying", + {"filename": filename, "error": str(e), "error_type": type(e).__name__}, + ) # Retry with default chunker try: @@ -80,31 +92,50 @@ class FileProcessor: chunks = text_chunker.default_chunk_text(page) if not chunks: log_event( - "file_processor.default_chunking_failed", {"filename": filename, "page_index": ocr_response.pages.index(page)} + "file_processor.default_chunking_failed", + { + "filename": filename, + "page_index": ocr_response.pages.index(page), + }, ) raise ValueError("No chunks created from text with default chunker") all_chunks.extend(chunks) all_passages = await self.embedder.generate_embedded_passages( - file_id=file_metadata.id, source_id=source_id, chunks=all_chunks, actor=self.actor + file_id=file_metadata.id, + source_id=source_id, + chunks=all_chunks, + actor=self.actor, ) logger.info(f"Successfully generated passages with default chunker for {filename}") - log_event("file_processor.default_chunking_success", {"filename": filename, "total_chunks": len(all_chunks)}) + log_event( + "file_processor.default_chunking_success", + {"filename": filename, "total_chunks": len(all_chunks)}, + ) return all_passages except Exception as fallback_error: logger.error("Default chunking also failed for %s: %s", filename, fallback_error) log_event( "file_processor.default_chunking_also_failed", - {"filename": filename, "fallback_error": str(fallback_error), "fallback_error_type": type(fallback_error).__name__}, + { + "filename": filename, + "fallback_error": str(fallback_error), + "fallback_error_type": type(fallback_error).__name__, + }, ) raise fallback_error # TODO: Factor this function out of SyncServer @trace_method async def process( - self, server: SyncServer, agent_states: List[AgentState], source_id: str, content: bytes, file_metadata: 
FileMetadata - ) -> List[Passage]: + self, + server: SyncServer, + agent_states: list[AgentState], + source_id: str, + content: bytes, + file_metadata: FileMetadata, + ) -> list[Passage]: filename = file_metadata.file_name # Create file as early as possible with no content @@ -170,18 +201,28 @@ class FileProcessor: raise ValueError("No text extracted from PDF") logger.info("Chunking extracted text") - log_event("file_processor.chunking_started", {"filename": filename, "pages_to_process": len(ocr_response.pages)}) + log_event( + "file_processor.chunking_started", + {"filename": filename, "pages_to_process": len(ocr_response.pages)}, + ) # Chunk and embed with fallback logic all_passages = await self._chunk_and_embed_with_fallback( - file_metadata=file_metadata, ocr_response=ocr_response, source_id=source_id + file_metadata=file_metadata, + ocr_response=ocr_response, + source_id=source_id, ) if not self.using_pinecone: all_passages = await self.passage_manager.create_many_source_passages_async( - passages=all_passages, file_metadata=file_metadata, actor=self.actor + passages=all_passages, + file_metadata=file_metadata, + actor=self.actor, + ) + log_event( + "file_processor.passages_created", + {"filename": filename, "total_passages": len(all_passages)}, ) - log_event("file_processor.passages_created", {"filename": filename, "total_passages": len(all_passages)}) logger.info(f"Successfully processed {filename}: {len(all_passages)} passages") log_event( @@ -197,11 +238,16 @@ class FileProcessor: # update job status if not self.using_pinecone: await self.file_manager.update_file_status( - file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED + file_id=file_metadata.id, + actor=self.actor, + processing_status=FileProcessingStatus.COMPLETED, ) else: await self.file_manager.update_file_status( - file_id=file_metadata.id, actor=self.actor, total_chunks=len(all_passages), chunks_embedded=0 + file_id=file_metadata.id, + actor=self.actor, 
+ total_chunks=len(all_passages), + chunks_embedded=0, ) return all_passages @@ -310,7 +356,10 @@ class FileProcessor: }, ) await self.file_manager.update_file_status( - file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.ERROR, error_message=str(e) + file_id=file_metadata.id, + actor=self.actor, + processing_status=FileProcessingStatus.ERROR, + error_message=str(e), ) return [] diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 96684271..91e15d50 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -220,6 +220,13 @@ class GroupManager: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) group.hard_delete(session) + @enforce_types + @trace_method + async def delete_group_async(self, group_id: str, actor: PydanticUser) -> None: + async with db_registry.async_session() as session: + group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) + await group.hard_delete_async(session) + @enforce_types @trace_method def list_group_messages( diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 610ffb2e..675ff763 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -207,7 +207,7 @@ class ProviderManager: @enforce_types @trace_method - def check_provider_api_key(self, provider_check: ProviderCheck) -> None: + async def check_provider_api_key(self, provider_check: ProviderCheck) -> None: provider = PydanticProvider( name=provider_check.provider_type.value, provider_type=provider_check.provider_type, @@ -221,4 +221,4 @@ class ProviderManager: if not provider.api_key: raise ValueError("API key is required") - provider.check_api_key() + await provider.check_api_key() diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index d3528597..d1f599aa 100644 --- a/tests/integration_test_async_tool_sandbox.py 
+++ b/tests/integration_test_async_tool_sandbox.py @@ -1,4 +1,4 @@ -import asyncio +import os import secrets import string import uuid @@ -18,6 +18,7 @@ from letta.schemas.organization import Organization from letta.schemas.pip_requirement import PipRequirement from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate from letta.schemas.user import User +from letta.server.db import db_registry from letta.server.server import SyncServer from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager @@ -32,6 +33,48 @@ namespace = uuid.NAMESPACE_DNS org_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-org")) user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) +# Set environment variable immediately to prevent pooling issues +os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true" + +# Recreate settings instance to pick up the environment variable +import letta.settings + +# Force settings reload after setting environment variable +from letta.settings import Settings + +letta.settings.settings = Settings() + + +# Disable SQLAlchemy connection pooling for tests to prevent event loop issues +@pytest.fixture(scope="session", autouse=True) +def disable_db_pooling_for_tests(): + """Disable database connection pooling for the entire test session.""" + # Environment variable is already set above and settings reloaded + yield + # Clean up environment variable after tests + if "LETTA_DISABLE_SQLALCHEMY_POOLING" in os.environ: + del os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] + + +@pytest.fixture(autouse=True) +async def cleanup_db_connections(): + """Cleanup database connections after each test.""" + yield + + # Dispose async engines in the current event loop + try: + if hasattr(db_registry, "_async_engines"): + for engine in db_registry._async_engines.values(): + if engine: + await engine.dispose() + # Reset async 
initialization to force fresh connections + db_registry._initialized["async"] = False + db_registry._async_engines.clear() + db_registry._async_session_factories.clear() + except Exception as e: + # Log the error but don't fail the test + print(f"Warning: Failed to cleanup database connections: {e}") + # Fixtures @pytest.fixture(scope="module") @@ -50,14 +93,14 @@ def server(): @pytest.fixture(autouse=True) -def clear_tables(): +async def clear_tables(): """Fixture to clear the organization table before each test.""" - from letta.server.db import db_context + from letta.server.db import db_registry - with db_context() as session: - session.execute(delete(SandboxEnvironmentVariable)) - session.execute(delete(SandboxConfig)) - session.commit() # Commit the deletion + async with db_registry.async_session() as session: + await session.execute(delete(SandboxEnvironmentVariable)) + await session.execute(delete(SandboxConfig)) + await session.commit() # Commit the deletion @pytest.fixture @@ -208,9 +251,9 @@ def external_codebase_tool(test_user): @pytest.fixture -def agent_state(server): +async def agent_state(server): actor = server.user_manager.get_user_or_default() - agent_state = server.create_agent( + agent_state = await server.create_agent_async( CreateAgent( memory_blocks=[ CreateBlock( @@ -471,12 +514,7 @@ def async_complex_tool(test_user): yield tool -@pytest.fixture(scope="session") -def event_loop(request): - """Create an instance of the default event loop for each test case.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() +# Removed custom event_loop fixture to avoid conflicts with pytest-asyncio # Local sandbox tests @@ -484,7 +522,7 @@ def event_loop(request): @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, test_user, event_loop): +async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, test_user): args = {"x": 10, "y": 
5} # Mock and assert correct pathway was invoked @@ -501,7 +539,7 @@ async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, tes @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memory_tool, test_user, agent_state, event_loop): +async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memory_tool, test_user, agent_state): args = {} sandbox = AsyncToolSandboxLocal(clear_core_memory_tool.name, args, user=test_user) result = await sandbox.run(agent_state=agent_state) @@ -513,7 +551,7 @@ async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memor @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user, event_loop): +async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user): sandbox = AsyncToolSandboxLocal(list_tool.name, {}, user=test_user) result = await sandbox.run() assert len(result.func_return) == 5 @@ -521,7 +559,7 @@ async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_u @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user, event_loop): +async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user): manager = SandboxConfigManager() sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump()) @@ -540,7 +578,7 @@ async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user, e @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, agent_state, test_user, event_loop): +async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, agent_state, test_user): manager = SandboxConfigManager() key = "secret_word" 
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") @@ -562,7 +600,7 @@ async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, ag @pytest.mark.asyncio @pytest.mark.local_sandbox async def test_local_sandbox_external_codebase_with_venv( - disable_e2b_api_key, custom_test_sandbox_config, external_codebase_tool, test_user, event_loop + disable_e2b_api_key, custom_test_sandbox_config, external_codebase_tool, test_user ): args = {"percentage": 10} sandbox = AsyncToolSandboxLocal(external_codebase_tool.name, args, user=test_user) @@ -574,7 +612,7 @@ async def test_local_sandbox_external_codebase_with_venv( @pytest.mark.asyncio @pytest.mark.local_sandbox async def test_local_sandbox_with_venv_and_warnings_does_not_error( - disable_e2b_api_key, custom_test_sandbox_config, get_warning_tool, test_user, event_loop + disable_e2b_api_key, custom_test_sandbox_config, get_warning_tool, test_user ): sandbox = AsyncToolSandboxLocal(get_warning_tool.name, {}, user=test_user) result = await sandbox.run() @@ -583,7 +621,7 @@ async def test_local_sandbox_with_venv_and_warnings_does_not_error( @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user, event_loop): +async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user): sandbox = AsyncToolSandboxLocal(always_err_tool.name, {}, user=test_user) result = await sandbox.run() assert len(result.stdout) != 0 @@ -594,7 +632,7 @@ async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_s @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, cowsay_tool, test_user, event_loop): +async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, cowsay_tool, test_user): manager = SandboxConfigManager() config_create = 
SandboxConfigCreate( config=LocalSandboxConfig(use_venv=True, pip_requirements=[PipRequirement(name="cowsay")]).model_dump() @@ -614,7 +652,7 @@ async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, c @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop): +async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user): """Test that local sandbox installs tool-specific pip requirements.""" manager = SandboxConfigManager() sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") @@ -634,7 +672,7 @@ async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, too @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop): +async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user): """Test that local sandbox installs both sandbox and tool pip requirements.""" manager = SandboxConfigManager() sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") @@ -658,7 +696,7 @@ async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, to @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user, event_loop): +async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user): manager = SandboxConfigManager() config_create = SandboxConfigCreate(config=LocalSandboxConfig(use_venv=True).model_dump()) config = manager.create_or_update_sandbox_config(config_create, test_user) @@ -689,7 +727,7 @@ async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_ @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async 
def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user, event_loop): +async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user): args = {"x": 10, "y": 5} # Mock and assert correct pathway was invoked @@ -706,7 +744,7 @@ async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user, event_loop): +async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user): manager = SandboxConfigManager() config_create = SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump()) config = manager.create_or_update_sandbox_config(config_create, test_user) @@ -726,7 +764,7 @@ async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_ @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state, event_loop): +async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state): sandbox = AsyncToolSandboxE2B(clear_core_memory_tool.name, {}, user=test_user) result = await sandbox.run(agent_state=agent_state) assert result.agent_state.memory.get_block("human").value == "" @@ -736,7 +774,7 @@ async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user, event_loop): +async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user): manager = SandboxConfigManager() config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump()) config = manager.create_or_update_sandbox_config(config_create, test_user) @@ -760,7 +798,7 @@ async def 
test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user, event_loop): +async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user): manager = SandboxConfigManager() key = "secret_word" wrong_val = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) @@ -784,7 +822,7 @@ async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, age @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user, event_loop): +async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user): sandbox = AsyncToolSandboxE2B(list_tool.name, {}, user=test_user) result = await sandbox.run() assert len(result.func_return) == 5 @@ -792,7 +830,7 @@ async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_us @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop): +async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user): """Test that E2B sandbox installs tool-specific pip requirements.""" manager = SandboxConfigManager() config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump()) @@ -809,7 +847,7 @@ async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop): +async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user): """Test that E2B sandbox installs both sandbox and tool pip requirements.""" manager = 
SandboxConfigManager() @@ -829,7 +867,7 @@ async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, too @pytest.mark.asyncio @pytest.mark.e2b_sandbox async def test_e2b_sandbox_with_broken_tool_pip_requirements_error_handling( - check_e2b_key_is_set, tool_with_broken_pip_requirements, test_user, event_loop + check_e2b_key_is_set, tool_with_broken_pip_requirements, test_user ): """Test that E2B sandbox provides informative error messages for broken tool pip requirements.""" manager = SandboxConfigManager() @@ -894,7 +932,7 @@ def test_async_template_selection(add_integers_tool, async_add_integers_tool, te @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_async_function_execution(disable_e2b_api_key, async_add_integers_tool, test_user, event_loop): +async def test_local_sandbox_async_function_execution(disable_e2b_api_key, async_add_integers_tool, test_user): """Test that async functions execute correctly in local sandbox""" args = {"x": 15, "y": 25} @@ -905,7 +943,7 @@ async def test_local_sandbox_async_function_execution(disable_e2b_api_key, async @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_async_function_execution(check_e2b_key_is_set, async_add_integers_tool, test_user, event_loop): +async def test_e2b_sandbox_async_function_execution(check_e2b_key_is_set, async_add_integers_tool, test_user): """Test that async functions execute correctly in E2B sandbox""" args = {"x": 20, "y": 30} @@ -916,7 +954,7 @@ async def test_e2b_sandbox_async_function_execution(check_e2b_key_is_set, async_ @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_async_complex_computation(disable_e2b_api_key, async_complex_tool, test_user, event_loop): +async def test_local_sandbox_async_complex_computation(disable_e2b_api_key, async_complex_tool, test_user): """Test complex async computation with multiple awaits in local sandbox""" args = {"iterations": 2} @@ -932,7 +970,7 @@ async def 
test_local_sandbox_async_complex_computation(disable_e2b_api_key, asyn @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_async_complex_computation(check_e2b_key_is_set, async_complex_tool, test_user, event_loop): +async def test_e2b_sandbox_async_complex_computation(check_e2b_key_is_set, async_complex_tool, test_user): """Test complex async computation with multiple awaits in E2B sandbox""" args = {"iterations": 2} @@ -949,7 +987,7 @@ async def test_e2b_sandbox_async_complex_computation(check_e2b_key_is_set, async @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_async_list_return(disable_e2b_api_key, async_list_tool, test_user, event_loop): +async def test_local_sandbox_async_list_return(disable_e2b_api_key, async_list_tool, test_user): """Test async function returning list in local sandbox""" sandbox = AsyncToolSandboxLocal(async_list_tool.name, {}, user=test_user) result = await sandbox.run() @@ -958,7 +996,7 @@ async def test_local_sandbox_async_list_return(disable_e2b_api_key, async_list_t @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_async_list_return(check_e2b_key_is_set, async_list_tool, test_user, event_loop): +async def test_e2b_sandbox_async_list_return(check_e2b_key_is_set, async_list_tool, test_user): """Test async function returning list in E2B sandbox""" sandbox = AsyncToolSandboxE2B(async_list_tool.name, {}, user=test_user) result = await sandbox.run() @@ -967,7 +1005,7 @@ async def test_e2b_sandbox_async_list_return(check_e2b_key_is_set, async_list_to @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_async_with_env_vars(disable_e2b_api_key, async_get_env_tool, test_user, event_loop): +async def test_local_sandbox_async_with_env_vars(disable_e2b_api_key, async_get_env_tool, test_user): """Test async function with environment variables in local sandbox""" manager = SandboxConfigManager() @@ -991,7 +1029,7 @@ async def 
test_local_sandbox_async_with_env_vars(disable_e2b_api_key, async_get_ @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_async_with_env_vars(check_e2b_key_is_set, async_get_env_tool, test_user, event_loop): +async def test_e2b_sandbox_async_with_env_vars(check_e2b_key_is_set, async_get_env_tool, test_user): """Test async function with environment variables in E2B sandbox""" manager = SandboxConfigManager() config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump()) @@ -1012,7 +1050,7 @@ async def test_e2b_sandbox_async_with_env_vars(check_e2b_key_is_set, async_get_e @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_async_with_agent_state(disable_e2b_api_key, async_stateful_tool, test_user, agent_state, event_loop): +async def test_local_sandbox_async_with_agent_state(disable_e2b_api_key, async_stateful_tool, test_user, agent_state): """Test async function with agent state in local sandbox""" sandbox = AsyncToolSandboxLocal(async_stateful_tool.name, {}, user=test_user) result = await sandbox.run(agent_state=agent_state) @@ -1025,7 +1063,7 @@ async def test_local_sandbox_async_with_agent_state(disable_e2b_api_key, async_s @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_async_with_agent_state(check_e2b_key_is_set, async_stateful_tool, test_user, agent_state, event_loop): +async def test_e2b_sandbox_async_with_agent_state(check_e2b_key_is_set, async_stateful_tool, test_user, agent_state): """Test async function with agent state in E2B sandbox""" sandbox = AsyncToolSandboxE2B(async_stateful_tool.name, {}, user=test_user) result = await sandbox.run(agent_state=agent_state) @@ -1037,7 +1075,7 @@ async def test_e2b_sandbox_async_with_agent_state(check_e2b_key_is_set, async_st @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_async_error_handling(disable_e2b_api_key, async_error_tool, test_user, event_loop): +async def 
test_local_sandbox_async_error_handling(disable_e2b_api_key, async_error_tool, test_user): """Test async function error handling in local sandbox""" sandbox = AsyncToolSandboxLocal(async_error_tool.name, {}, user=test_user) result = await sandbox.run() @@ -1051,7 +1089,7 @@ async def test_local_sandbox_async_error_handling(disable_e2b_api_key, async_err @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_async_error_handling(check_e2b_key_is_set, async_error_tool, test_user, event_loop): +async def test_e2b_sandbox_async_error_handling(check_e2b_key_is_set, async_error_tool, test_user): """Test async function error handling in E2B sandbox""" sandbox = AsyncToolSandboxE2B(async_error_tool.name, {}, user=test_user) result = await sandbox.run() @@ -1065,7 +1103,7 @@ async def test_e2b_sandbox_async_error_handling(check_e2b_key_is_set, async_erro @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_async_per_agent_env(disable_e2b_api_key, async_get_env_tool, agent_state, test_user, event_loop): +async def test_local_sandbox_async_per_agent_env(disable_e2b_api_key, async_get_env_tool, agent_state, test_user): """Test async function with per-agent environment variables in local sandbox""" manager = SandboxConfigManager() key = "secret_word" @@ -1087,7 +1125,7 @@ async def test_local_sandbox_async_per_agent_env(disable_e2b_api_key, async_get_ @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_env_tool, agent_state, test_user, event_loop): +async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_env_tool, agent_state, test_user): """Test async function with per-agent environment variables in E2B sandbox""" manager = SandboxConfigManager() key = "secret_word" diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index 2dda36a9..b4fa4eb9 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -1,8 +1,8 
@@ +import os + import pytest -from sqlalchemy import delete from letta.config import LettaConfig -from letta.orm import Provider, ProviderTrace, Step from letta.schemas.agent import CreateAgent from letta.schemas.block import CreateBlock from letta.schemas.group import ( @@ -19,6 +19,37 @@ from letta.server.db import db_registry from letta.server.server import SyncServer +# Disable SQLAlchemy connection pooling for tests to prevent event loop issues +@pytest.fixture(scope="session", autouse=True) +def disable_db_pooling_for_tests(): + """Disable database connection pooling for the entire test session.""" + os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true" + yield + # Clean up environment variable after tests + if "LETTA_DISABLE_SQLALCHEMY_POOLING" in os.environ: + del os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] + + +@pytest.fixture(autouse=True) +async def cleanup_db_connections(): + """Cleanup database connections after each test.""" + yield + + # Dispose async engines in the current event loop + try: + if hasattr(db_registry, "_async_engines"): + for engine in db_registry._async_engines.values(): + if engine: + await engine.dispose() + # Reset async initialization to force fresh connections + db_registry._initialized["async"] = False + db_registry._async_engines.clear() + db_registry._async_session_factories.clear() + except Exception as e: + # Log the error but don't fail the test + print(f"Warning: Failed to cleanup database connections: {e}") + + @pytest.fixture(scope="module") def server(): config = LettaConfig.load() @@ -30,33 +61,21 @@ def server(): return server -@pytest.fixture(scope="module") -def org_id(server): - org = server.organization_manager.create_default_organization() - - yield org.id - - # cleanup - with db_registry.session() as session: - session.execute(delete(ProviderTrace)) - session.execute(delete(Step)) - session.execute(delete(Provider)) - session.commit() - server.organization_manager.delete_organization_by_id(org.id) 
+@pytest.fixture +async def default_organization(server: SyncServer): + """Fixture to create and return the default organization.""" + yield await server.organization_manager.create_default_organization_async() -@pytest.fixture(scope="module") -def actor(server, org_id): - user = server.user_manager.create_default_user() - yield user - - # cleanup - server.user_manager.delete_user_by_id(user.id) +@pytest.fixture +async def default_user(server: SyncServer, default_organization): + """Fixture to create and return the default user within the default organization.""" + yield await server.user_manager.create_default_actor_async(org_id=default_organization.id) -@pytest.fixture(scope="module") -def participant_agents(server, actor): - agent_fred = server.create_agent( +@pytest.fixture +async def four_participant_agents(server, default_user): + agent_fred = await server.create_agent_async( request=CreateAgent( name="fred", memory_blocks=[ @@ -68,9 +87,9 @@ def participant_agents(server, actor): model="openai/gpt-4o-mini", embedding="openai/text-embedding-3-small", ), - actor=actor, + actor=default_user, ) - agent_velma = server.create_agent( + agent_velma = await server.create_agent_async( request=CreateAgent( name="velma", memory_blocks=[ @@ -82,9 +101,9 @@ def participant_agents(server, actor): model="openai/gpt-4o-mini", embedding="openai/text-embedding-3-small", ), - actor=actor, + actor=default_user, ) - agent_daphne = server.create_agent( + agent_daphne = await server.create_agent_async( request=CreateAgent( name="daphne", memory_blocks=[ @@ -96,9 +115,9 @@ def participant_agents(server, actor): model="openai/gpt-4o-mini", embedding="openai/text-embedding-3-small", ), - actor=actor, + actor=default_user, ) - agent_shaggy = server.create_agent( + agent_shaggy = await server.create_agent_async( request=CreateAgent( name="shaggy", memory_blocks=[ @@ -110,20 +129,14 @@ def participant_agents(server, actor): model="openai/gpt-4o-mini", 
embedding="openai/text-embedding-3-small", ), - actor=actor, + actor=default_user, ) yield [agent_fred, agent_velma, agent_daphne, agent_shaggy] - # cleanup - server.agent_manager.delete_agent(agent_fred.id, actor=actor) - server.agent_manager.delete_agent(agent_velma.id, actor=actor) - server.agent_manager.delete_agent(agent_daphne.id, actor=actor) - server.agent_manager.delete_agent(agent_shaggy.id, actor=actor) - -@pytest.fixture(scope="module") -def manager_agent(server, actor): - agent_scooby = server.create_agent( +@pytest.fixture +async def manager_agent(server, default_user): + agent_scooby = await server.create_agent_async( request=CreateAgent( name="scooby", memory_blocks=[ @@ -139,27 +152,24 @@ def manager_agent(server, actor): model="openai/gpt-4o-mini", embedding="openai/text-embedding-3-small", ), - actor=actor, + actor=default_user, ) yield agent_scooby - # cleanup - server.agent_manager.delete_agent(agent_scooby.id, actor=actor) - -@pytest.mark.asyncio(loop_scope="module") -async def test_empty_group(server, actor): - group = server.group_manager.create_group( +@pytest.mark.asyncio +async def test_empty_group(server, default_user): + group = await server.group_manager.create_group_async( group=GroupCreate( description="This is a group chat between best friends all like to hang out together. 
In their free time they like to solve mysteries.", agent_ids=[], ), - actor=actor, + actor=default_user, ) with pytest.raises(ValueError, match="Empty group"): await server.send_group_message_to_agent( group_id=group.id, - actor=actor, + actor=default_user, input_messages=[ MessageCreate( role="user", @@ -169,17 +179,17 @@ async def test_empty_group(server, actor): stream_steps=False, stream_tokens=False, ) - server.group_manager.delete_group(group_id=group.id, actor=actor) + await server.group_manager.delete_group_async(group_id=group.id, actor=default_user) -@pytest.mark.asyncio(loop_scope="module") -async def test_modify_group_pattern(server, actor, participant_agents, manager_agent): - group = server.group_manager.create_group( +@pytest.mark.asyncio +async def test_modify_group_pattern(server, default_user, four_participant_agents, manager_agent): + group = await server.group_manager.create_group_async( group=GroupCreate( description="This is a group chat between best friends all like to hang out together. 
In their free time they like to solve mysteries.", - agent_ids=[agent.id for agent in participant_agents], + agent_ids=[agent.id for agent in four_participant_agents], ), - actor=actor, + actor=default_user, ) with pytest.raises(ValueError, match="Cannot change group pattern"): await server.group_manager.modify_group_async( @@ -190,64 +200,64 @@ async def test_modify_group_pattern(server, actor, participant_agents, manager_a manager_agent_id=manager_agent.id, ), ), - actor=actor, + actor=default_user, ) - server.group_manager.delete_group(group_id=group.id, actor=actor) + await server.group_manager.delete_group_async(group_id=group.id, actor=default_user) -@pytest.mark.asyncio(loop_scope="module") -async def test_list_agent_groups(server, actor, participant_agents): - group_a = server.group_manager.create_group( +@pytest.mark.asyncio +async def test_list_agent_groups(server, default_user, four_participant_agents): + group_a = await server.group_manager.create_group_async( group=GroupCreate( description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.", - agent_ids=[agent.id for agent in participant_agents], + agent_ids=[agent.id for agent in four_participant_agents], ), - actor=actor, + actor=default_user, ) - group_b = server.group_manager.create_group( + group_b = await server.group_manager.create_group_async( group=GroupCreate( description="This is a group chat between best friends all like to hang out together. 
In their free time they like to solve mysteries.", - agent_ids=[participant_agents[0].id], + agent_ids=[four_participant_agents[0].id], ), - actor=actor, + actor=default_user, ) - agent_a_groups = server.agent_manager.list_groups(agent_id=participant_agents[0].id, actor=actor) + agent_a_groups = server.agent_manager.list_groups(agent_id=four_participant_agents[0].id, actor=default_user) assert sorted([group.id for group in agent_a_groups]) == sorted([group_a.id, group_b.id]) - agent_b_groups = server.agent_manager.list_groups(agent_id=participant_agents[1].id, actor=actor) + agent_b_groups = server.agent_manager.list_groups(agent_id=four_participant_agents[1].id, actor=default_user) assert [group.id for group in agent_b_groups] == [group_a.id] - server.group_manager.delete_group(group_id=group_a.id, actor=actor) - server.group_manager.delete_group(group_id=group_b.id, actor=actor) + await server.group_manager.delete_group_async(group_id=group_a.id, actor=default_user) + await server.group_manager.delete_group_async(group_id=group_b.id, actor=default_user) -@pytest.mark.asyncio(loop_scope="module") -async def test_round_robin(server, actor, participant_agents): +@pytest.mark.asyncio +async def test_round_robin(server, default_user, four_participant_agents): description = ( "This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries." 
) - group = server.group_manager.create_group( + group = await server.group_manager.create_group_async( group=GroupCreate( description=description, - agent_ids=[agent.id for agent in participant_agents], + agent_ids=[agent.id for agent in four_participant_agents], ), - actor=actor, + actor=default_user, ) # verify group creation assert group.manager_type == ManagerType.round_robin assert group.description == description - assert group.agent_ids == [agent.id for agent in participant_agents] + assert group.agent_ids == [agent.id for agent in four_participant_agents] assert group.max_turns is None assert group.manager_agent_id is None assert group.termination_token is None try: - server.group_manager.reset_messages(group_id=group.id, actor=actor) + server.group_manager.reset_messages(group_id=group.id, actor=default_user) response = await server.send_group_message_to_agent( group_id=group.id, - actor=actor, + actor=default_user, input_messages=[ MessageCreate( role="user", @@ -261,11 +271,11 @@ async def test_round_robin(server, actor, participant_agents): assert len(response.messages) == response.usage.step_count * 2 for i, message in enumerate(response.messages): assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message" - assert message.name == participant_agents[i // 2].name + assert message.name == four_participant_agents[i // 2].name for agent_id in group.agent_ids: agent_messages = server.get_agent_recall( - user_id=actor.id, + user_id=default_user.id, agent_id=agent_id, group_id=group.id, reverse=True, @@ -276,7 +286,7 @@ async def test_round_robin(server, actor, participant_agents): # TODO: filter this to return a clean conversation history messages = server.group_manager.list_group_messages( group_id=group.id, - actor=actor, + actor=default_user, ) assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids) @@ -284,25 +294,25 @@ async def test_round_robin(server, actor, participant_agents): group = await 
server.group_manager.modify_group_async( group_id=group.id, group_update=GroupUpdate( - agent_ids=[agent.id for agent in participant_agents][::-1], + agent_ids=[agent.id for agent in four_participant_agents][::-1], manager_config=RoundRobinManagerUpdate( max_turns=max_turns, ), ), - actor=actor, + actor=default_user, ) assert group.manager_type == ManagerType.round_robin assert group.description == description - assert group.agent_ids == [agent.id for agent in participant_agents][::-1] + assert group.agent_ids == [agent.id for agent in four_participant_agents][::-1] assert group.max_turns == max_turns assert group.manager_agent_id is None assert group.termination_token is None - server.group_manager.reset_messages(group_id=group.id, actor=actor) + server.group_manager.reset_messages(group_id=group.id, actor=default_user) response = await server.send_group_message_to_agent( group_id=group.id, - actor=actor, + actor=default_user, input_messages=[ MessageCreate( role="user", @@ -317,11 +327,11 @@ async def test_round_robin(server, actor, participant_agents): for i, message in enumerate(response.messages): assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message" - assert message.name == participant_agents[::-1][i // 2].name + assert message.name == four_participant_agents[::-1][i // 2].name for i in range(len(group.agent_ids)): agent_messages = server.get_agent_recall( - user_id=actor.id, + user_id=default_user.id, agent_id=group.agent_ids[i], group_id=group.id, reverse=True, @@ -331,12 +341,12 @@ async def test_round_robin(server, actor, participant_agents): assert len(agent_messages) == expected_message_count finally: - server.group_manager.delete_group(group_id=group.id, actor=actor) + await server.group_manager.delete_group_async(group_id=group.id, actor=default_user) -@pytest.mark.asyncio(loop_scope="module") -async def test_supervisor(server, actor, participant_agents): - agent_scrappy = server.create_agent( +@pytest.mark.asyncio 
+async def test_supervisor(server, default_user, four_participant_agents): + agent_scrappy = await server.create_agent_async( request=CreateAgent( name="shaggy", memory_blocks=[ @@ -352,23 +362,23 @@ async def test_supervisor(server, actor, participant_agents): model="openai/gpt-4o-mini", embedding="openai/text-embedding-3-small", ), - actor=actor, + actor=default_user, ) - group = server.group_manager.create_group( + group = await server.group_manager.create_group_async( group=GroupCreate( description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.", - agent_ids=[agent.id for agent in participant_agents], + agent_ids=[agent.id for agent in four_participant_agents], manager_config=SupervisorManager( manager_agent_id=agent_scrappy.id, ), ), - actor=actor, + actor=default_user, ) try: response = await server.send_group_message_to_agent( group_id=group.id, - actor=actor, + actor=default_user, input_messages=[ MessageCreate( role="user", @@ -388,32 +398,33 @@ async def test_supervisor(server, actor, participant_agents): and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group" ) assert response.messages[2].message_type == "tool_return_message" and len(eval(response.messages[2].tool_return)) == len( - participant_agents + four_participant_agents ) assert response.messages[3].message_type == "reasoning_message" assert response.messages[4].message_type == "assistant_message" finally: - server.group_manager.delete_group(group_id=group.id, actor=actor) - server.agent_manager.delete_agent(agent_id=agent_scrappy.id, actor=actor) + await server.group_manager.delete_group_async(group_id=group.id, actor=default_user) + server.agent_manager.delete_agent(agent_id=agent_scrappy.id, actor=default_user) -@pytest.mark.asyncio(loop_scope="module") -async def test_dynamic_group_chat(server, actor, manager_agent, participant_agents): +@pytest.mark.asyncio +@pytest.mark.flaky(max_runs=2) 
+async def test_dynamic_group_chat(server, default_user, manager_agent, four_participant_agents): description = ( "This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries." ) # error on duplicate agent in participant list with pytest.raises(ValueError, match="Duplicate agent ids"): - server.group_manager.create_group( + await server.group_manager.create_group_async( group=GroupCreate( description=description, - agent_ids=[agent.id for agent in participant_agents] + [participant_agents[0].id], + agent_ids=[agent.id for agent in four_participant_agents] + [four_participant_agents[0].id], manager_config=DynamicManager( manager_agent_id=manager_agent.id, ), ), - actor=actor, + actor=default_user, ) # error on duplicate agent names duplicate_agent_shaggy = server.create_agent( @@ -422,43 +433,43 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen model="openai/gpt-4o-mini", embedding="openai/text-embedding-3-small", ), - actor=actor, + actor=default_user, ) with pytest.raises(ValueError, match="Duplicate agent names"): - server.group_manager.create_group( + await server.group_manager.create_group_async( group=GroupCreate( description=description, - agent_ids=[agent.id for agent in participant_agents] + [duplicate_agent_shaggy.id], + agent_ids=[agent.id for agent in four_participant_agents] + [duplicate_agent_shaggy.id], manager_config=DynamicManager( manager_agent_id=manager_agent.id, ), ), - actor=actor, + actor=default_user, ) - server.agent_manager.delete_agent(duplicate_agent_shaggy.id, actor=actor) + server.agent_manager.delete_agent(duplicate_agent_shaggy.id, actor=default_user) - group = server.group_manager.create_group( + group = await server.group_manager.create_group_async( group=GroupCreate( description=description, - agent_ids=[agent.id for agent in participant_agents], + agent_ids=[agent.id for agent in four_participant_agents], manager_config=DynamicManager( 
manager_agent_id=manager_agent.id, ), ), - actor=actor, + actor=default_user, ) try: response = await server.send_group_message_to_agent( group_id=group.id, - actor=actor, + actor=default_user, input_messages=[ MessageCreate(role="user", content="what is everyone up to for the holidays?"), ], stream_steps=False, stream_tokens=False, ) - assert response.usage.step_count == len(participant_agents) * 2 + assert response.usage.step_count == len(four_participant_agents) * 2 assert len(response.messages) == response.usage.step_count * 2 finally: - server.group_manager.delete_group(group_id=group.id, actor=actor) + await server.group_manager.delete_group_async(group_id=group.id, actor=default_user) diff --git a/tests/test_providers.py b/tests/test_providers.py index 439d3417..71230218 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -10,7 +10,7 @@ from letta.schemas.providers import ( OllamaProvider, OpenAIProvider, TogetherProvider, - VLLMChatCompletionsProvider, + VLLMProvider, ) from letta.settings import model_settings @@ -46,25 +46,8 @@ async def test_openai_async(): assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_deepseek(): - provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - -def test_anthropic(): - provider = AnthropicProvider( - name="anthropic", - api_key=model_settings.anthropic_api_key, - ) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - @pytest.mark.asyncio -async def test_anthropic_async(): +async def test_anthropic(): provider = AnthropicProvider( name="anthropic", api_key=model_settings.anthropic_api_key, @@ -74,67 +57,8 @@ async def test_anthropic_async(): assert models[0].handle == f"{provider.name}/{models[0].model}" -def 
test_groq(): - provider = GroqProvider( - name="groq", - api_key=model_settings.groq_api_key, - ) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - -def test_azure(): - provider = AzureProvider( - name="azure", - api_key=model_settings.azure_api_key, - base_url=model_settings.azure_base_url, - api_version=model_settings.azure_api_version, - ) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - embedding_models = provider.list_embedding_models() - assert len(embedding_models) > 0 - assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" - - -@pytest.mark.skipif(model_settings.ollama_base_url is None, reason="Only run if OLLAMA_BASE_URL is set.") -def test_ollama(): - provider = OllamaProvider( - name="ollama", - base_url=model_settings.ollama_base_url, - api_key=None, - default_prompt_formatter=model_settings.default_prompt_formatter, - ) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - embedding_models = provider.list_embedding_models() - assert len(embedding_models) > 0 - assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" - - -def test_googleai(): - api_key = model_settings.gemini_api_key - assert api_key is not None - provider = GoogleAIProvider( - name="google_ai", - api_key=api_key, - ) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - embedding_models = provider.list_embedding_models() - assert len(embedding_models) > 0 - assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" - - @pytest.mark.asyncio -async def test_googleai_async(): +async def test_googleai(): api_key = model_settings.gemini_api_key assert api_key is 
not None provider = GoogleAIProvider( @@ -150,42 +74,64 @@ async def test_googleai_async(): assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_google_vertex(): +@pytest.mark.asyncio +async def test_google_vertex(): provider = GoogleVertexProvider( name="google_vertex", google_cloud_project=model_settings.google_cloud_project, google_cloud_location=model_settings.google_cloud_location, ) - models = provider.list_llm_models() + models = await provider.list_llm_models_async() assert len(models) > 0 assert models[0].handle == f"{provider.name}/{models[0].model}" - embedding_models = provider.list_embedding_models() + embedding_models = await provider.list_embedding_models_async() assert len(embedding_models) > 0 assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_together(): - provider = TogetherProvider( - name="together", - api_key=model_settings.together_api_key, - default_prompt_formatter=model_settings.default_prompt_formatter, - ) - models = provider.list_llm_models() - assert len(models) > 0 - # Handle may be different from raw model name due to LLM_HANDLE_OVERRIDES - assert models[0].handle.startswith(f"{provider.name}/") - # Verify the handle is properly constructed via get_handle method - assert models[0].handle == provider.get_handle(models[0].model) - - # TODO: We don't have embedding models on together for CI - # embedding_models = provider.list_embedding_models() - # assert len(embedding_models) > 0 - # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" - - +@pytest.mark.skipif(model_settings.deepseek_api_key is None, reason="Only run if DEEPSEEK_API_KEY is set.") @pytest.mark.asyncio -async def test_together_async(): +async def test_deepseek(): + provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key) + models = await provider.list_llm_models_async() + assert len(models) > 0 + 
assert models[0].handle == f"{provider.name}/{models[0].model}" + + +@pytest.mark.skipif(model_settings.groq_api_key is None, reason="Only run if GROQ_API_KEY is set.") +@pytest.mark.asyncio +async def test_groq(): + provider = GroqProvider( + name="groq", + api_key=model_settings.groq_api_key, + ) + models = await provider.list_llm_models_async() + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" + + +@pytest.mark.skipif(model_settings.azure_api_key is None, reason="Only run if AZURE_API_KEY is set.") +@pytest.mark.asyncio +async def test_azure(): + provider = AzureProvider( + name="azure", + api_key=model_settings.azure_api_key, + base_url=model_settings.azure_base_url, + api_version=model_settings.azure_api_version, + ) + models = await provider.list_llm_models_async() + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" + + embedding_models = await provider.list_embedding_models_async() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + + +@pytest.mark.skipif(model_settings.together_api_key is None, reason="Only run if TOGETHER_API_KEY is set.") +@pytest.mark.asyncio +async def test_together(): provider = TogetherProvider( name="together", api_key=model_settings.together_api_key, @@ -204,6 +150,33 @@ async def test_together_async(): # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" +# ===== Local Models ===== +@pytest.mark.skipif(model_settings.ollama_base_url is None, reason="Only run if OLLAMA_BASE_URL is set.") +@pytest.mark.asyncio +async def test_ollama(): + provider = OllamaProvider( + name="ollama", + base_url=model_settings.ollama_base_url, + api_key=None, + default_prompt_formatter=model_settings.default_prompt_formatter, + ) + models = await provider.list_llm_models_async() + assert len(models) > 0 + assert models[0].handle == 
f"{provider.name}/{models[0].model}" + + embedding_models = await provider.list_embedding_models_async() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + + +@pytest.mark.skipif(model_settings.vllm_api_base is None, reason="Only run if VLLM_API_BASE is set.") +@pytest.mark.asyncio +async def test_vllm(): + provider = VLLMProvider(base_url=model_settings.vllm_api_base) + models = await provider.list_llm_models_async() + print(models) + + # TODO: Add back in, difficulty adding this to CI properly, need boto credentials # def test_anthropic_bedrock(): # from letta.settings import model_settings @@ -218,20 +191,122 @@ async def test_together_async(): # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_custom_anthropic(): +async def test_custom_anthropic(): provider = AnthropicProvider( name="custom_anthropic", api_key=model_settings.anthropic_api_key, ) - models = provider.list_llm_models() + models = await provider.list_llm_models_async() assert len(models) > 0 assert models[0].handle == f"{provider.name}/{models[0].model}" -@pytest.mark.skipif(model_settings.vllm_api_base is None, reason="Only run if VLLM_API_BASE is set.") -def test_vllm(): - provider = VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base) - models = provider.list_llm_models() - print(models) +def test_provider_context_window(): + """Test that providers implement context window methods correctly.""" + provider = OpenAIProvider( + name="openai", + api_key=model_settings.openai_api_key, + base_url=model_settings.openai_api_base, + ) - provider.list_embedding_models() + # Test both sync and async context window methods + context_window = provider.get_model_context_window("gpt-4") + assert context_window is not None + assert isinstance(context_window, int) + assert context_window > 0 + + +@pytest.mark.asyncio +async def 
test_provider_context_window_async(): + """Test that providers implement async context window methods correctly.""" + provider = OpenAIProvider( + name="openai", + api_key=model_settings.openai_api_key, + base_url=model_settings.openai_api_base, + ) + + context_window = await provider.get_model_context_window_async("gpt-4") + assert context_window is not None + assert isinstance(context_window, int) + assert context_window > 0 + + +def test_provider_handle_generation(): + """Test that providers generate handles correctly.""" + provider = OpenAIProvider( + name="test_openai", + api_key="test_key", + base_url="https://api.openai.com/v1", + ) + + # Test LLM handle + llm_handle = provider.get_handle("gpt-4") + assert llm_handle == "test_openai/gpt-4" + + # Test embedding handle + embedding_handle = provider.get_handle("text-embedding-ada-002", is_embedding=True) + assert embedding_handle == "test_openai/text-embedding-ada-002" + + +def test_provider_casting(): + """Test that providers can be cast to their specific subtypes.""" + from letta.schemas.enums import ProviderCategory, ProviderType + from letta.schemas.providers.base import Provider + + base_provider = Provider( + name="test_provider", + provider_type=ProviderType.openai, + provider_category=ProviderCategory.base, + api_key="test_key", + base_url="https://api.openai.com/v1", + ) + + cast_provider = base_provider.cast_to_subtype() + assert isinstance(cast_provider, OpenAIProvider) + assert cast_provider.name == "test_provider" + assert cast_provider.api_key == "test_key" + + +@pytest.mark.asyncio +async def test_provider_embedding_models_consistency(): + """Test that providers return consistent embedding model formats.""" + provider = OpenAIProvider( + name="openai", + api_key=model_settings.openai_api_key, + base_url=model_settings.openai_api_base, + ) + + embedding_models = await provider.list_embedding_models_async() + if embedding_models: # Only test if provider supports embedding models + for model in 
embedding_models: + assert hasattr(model, "embedding_model") + assert hasattr(model, "embedding_endpoint_type") + assert hasattr(model, "embedding_endpoint") + assert hasattr(model, "embedding_dim") + assert hasattr(model, "handle") + assert model.handle.startswith(f"{provider.name}/") + + +@pytest.mark.asyncio +async def test_provider_llm_models_consistency(): + """Test that providers return consistent LLM model formats.""" + provider = OpenAIProvider( + name="openai", + api_key=model_settings.openai_api_key, + base_url=model_settings.openai_api_base, + ) + + models = await provider.list_llm_models_async() + assert len(models) > 0 + + for model in models: + assert hasattr(model, "model") + assert hasattr(model, "model_endpoint_type") + assert hasattr(model, "model_endpoint") + assert hasattr(model, "context_window") + assert hasattr(model, "handle") + assert hasattr(model, "provider_name") + assert hasattr(model, "provider_category") + assert model.handle.startswith(f"{provider.name}/") + assert model.provider_name == provider.name + assert model.context_window > 0 diff --git a/tests/test_server.py b/tests/test_server.py index 23756386..ec9007cc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -991,7 +991,8 @@ def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_m assert len(agent_state.tool_rules) == len(base_tools + base_memory_tools) -def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools, base_memory_tools): +@pytest.mark.asyncio +async def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools, base_memory_tools): """Test that the memory rebuild is generating the correct number of role=system messages""" actor = server.user_manager.get_user_or_default(user_id) @@ -1055,12 +1056,12 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to # Add all the base tools request.tool_ids = [b.id for b in base_tools] - agent_state = 
server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor) + agent_state = await server.agent_manager.update_agent_async(agent_state.id, agent_update=request, actor=actor) assert len(agent_state.tools) == len(base_tools) # Remove one base tool request.tool_ids = [b.id for b in base_tools[:-2]] - agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor) + agent_state = await server.agent_manager.update_agent_async(agent_state.id, agent_update=request, actor=actor) assert len(agent_state.tools) == len(base_tools) - 2 diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index a228e250..0e5f4ed5 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -59,7 +59,7 @@ def test_get_allowed_tool_names_no_matching_rule_error(): solver = ToolRulesSolver(tool_rules=[init_rule]) solver.register_tool_call(UNRECOGNIZED_TOOL) - with pytest.raises(ValueError, match=f"No valid tools found based on tool rules."): + with pytest.raises(ValueError, match="No valid tools found based on tool rules."): solver.get_allowed_tool_names(set(), error_on_empty=True)