Files
letta-server/tests/test_providers.py
2025-05-20 16:00:37 -07:00

228 lines
7.4 KiB
Python

import pytest
from letta.schemas.providers import (
AnthropicProvider,
AzureProvider,
DeepSeekProvider,
GoogleAIProvider,
GoogleVertexProvider,
GroqProvider,
OpenAIProvider,
TogetherProvider,
)
from letta.settings import model_settings
def test_openai():
    """Check OpenAIProvider's synchronous LLM and embedding model listings.

    The provider is configured from ``model_settings`` (API key and base URL).
    Both listings must be non-empty, and the first entry of each must carry the
    handle ``"<provider name>/<model name>"``.
    """
    openai_provider = OpenAIProvider(
        name="openai",
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )

    llm_models = openai_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{openai_provider.name}/{first_llm.model}"

    embeddings = openai_provider.list_embedding_models()
    assert len(embeddings) > 0
    first_embedding = embeddings[0]
    assert first_embedding.handle == f"{openai_provider.name}/{first_embedding.embedding_model}"
@pytest.mark.asyncio
async def test_openai_async():
    """Async counterpart of ``test_openai`` using the ``*_async`` listing methods.

    Both the awaited LLM listing and the awaited embedding listing must be
    non-empty, and each first entry's handle must be
    ``"<provider name>/<model name>"``.
    """
    openai_provider = OpenAIProvider(
        name="openai",
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )

    llm_models = await openai_provider.list_llm_models_async()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{openai_provider.name}/{first_llm.model}"

    embeddings = await openai_provider.list_embedding_models_async()
    assert len(embeddings) > 0
    first_embedding = embeddings[0]
    assert first_embedding.handle == f"{openai_provider.name}/{first_embedding.embedding_model}"
def test_deepseek():
    """Check that DeepSeekProvider lists LLM models with correctly-formed handles.

    Only LLM models are checked here; the first one's handle must be
    ``"<provider name>/<model name>"``.
    """
    deepseek_provider = DeepSeekProvider(
        name="deepseek",
        api_key=model_settings.deepseek_api_key,
    )

    llm_models = deepseek_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{deepseek_provider.name}/{first_llm.model}"
def test_anthropic():
    """Check that AnthropicProvider lists LLM models with correctly-formed handles.

    The provider is named ``"anthropic"`` and uses the Anthropic API key from
    ``model_settings``; the first listed model's handle must be
    ``"<provider name>/<model name>"``.
    """
    anthropic_provider = AnthropicProvider(
        name="anthropic",
        api_key=model_settings.anthropic_api_key,
    )

    llm_models = anthropic_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{anthropic_provider.name}/{first_llm.model}"
@pytest.mark.asyncio
async def test_anthropic_async():
    """Async counterpart of ``test_anthropic`` using ``list_llm_models_async``.

    The awaited listing must be non-empty, and the first model's handle must be
    ``"<provider name>/<model name>"``.
    """
    anthropic_provider = AnthropicProvider(
        name="anthropic",
        api_key=model_settings.anthropic_api_key,
    )

    llm_models = await anthropic_provider.list_llm_models_async()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{anthropic_provider.name}/{first_llm.model}"
def test_groq():
    """Check that GroqProvider lists LLM models with correctly-formed handles.

    The provider uses the Groq API key from ``model_settings``; the first
    listed model's handle must be ``"<provider name>/<model name>"``.
    """
    groq_provider = GroqProvider(
        name="groq",
        api_key=model_settings.groq_api_key,
    )

    llm_models = groq_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{groq_provider.name}/{first_llm.model}"
def test_azure():
    """Check AzureProvider's LLM and embedding model listings.

    The provider is configured with the Azure API key, base URL and API
    version from ``model_settings``. Both listings must be non-empty, and the
    first entry of each must carry the handle ``"<provider name>/<model name>"``.
    """
    azure_config = {
        "name": "azure",
        "api_key": model_settings.azure_api_key,
        "base_url": model_settings.azure_base_url,
        "api_version": model_settings.azure_api_version,
    }
    azure_provider = AzureProvider(**azure_config)

    llm_models = azure_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{azure_provider.name}/{first_llm.model}"

    embeddings = azure_provider.list_embedding_models()
    assert len(embeddings) > 0
    first_embedding = embeddings[0]
    assert first_embedding.handle == f"{azure_provider.name}/{first_embedding.embedding_model}"
# def test_ollama():
# provider = OllamaProvider(
# name="ollama",
# base_url=model_settings.ollama_base_url,
# api_key=None,
# default_prompt_formatter=model_settings.default_prompt_formatter,
# )
# models = provider.list_llm_models()
# assert len(models) > 0
# assert models[0].handle == f"{provider.name}/{models[0].model}"
#
# embedding_models = provider.list_embedding_models()
# assert len(embedding_models) > 0
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_googleai():
    """Check GoogleAIProvider's LLM and embedding model listings.

    A Gemini API key must be configured in ``model_settings``; the test fails
    up front if it is ``None``. Both listings must be non-empty, and the first
    entry of each must carry the handle ``"<provider name>/<model name>"``.
    """
    gemini_key = model_settings.gemini_api_key
    assert gemini_key is not None

    google_provider = GoogleAIProvider(
        name="google_ai",
        api_key=gemini_key,
    )

    llm_models = google_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{google_provider.name}/{first_llm.model}"

    embeddings = google_provider.list_embedding_models()
    assert len(embeddings) > 0
    first_embedding = embeddings[0]
    assert first_embedding.handle == f"{google_provider.name}/{first_embedding.embedding_model}"
@pytest.mark.asyncio
async def test_googleai_async():
    """Async counterpart of ``test_googleai`` using the ``*_async`` listing methods.

    Requires a non-``None`` Gemini API key in ``model_settings``. Both awaited
    listings must be non-empty, with first-entry handles of the form
    ``"<provider name>/<model name>"``.
    """
    gemini_key = model_settings.gemini_api_key
    assert gemini_key is not None

    google_provider = GoogleAIProvider(
        name="google_ai",
        api_key=gemini_key,
    )

    llm_models = await google_provider.list_llm_models_async()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{google_provider.name}/{first_llm.model}"

    embeddings = await google_provider.list_embedding_models_async()
    assert len(embeddings) > 0
    first_embedding = embeddings[0]
    assert first_embedding.handle == f"{google_provider.name}/{first_embedding.embedding_model}"
def test_google_vertex():
    """Check GoogleVertexProvider's LLM and embedding model listings.

    The provider is configured with the Google Cloud project and location from
    ``model_settings`` (no API key is passed here). Both listings must be
    non-empty, and the first entry of each must carry the handle
    ``"<provider name>/<model name>"``.
    """
    vertex_config = {
        "name": "google_vertex",
        "google_cloud_project": model_settings.google_cloud_project,
        "google_cloud_location": model_settings.google_cloud_location,
    }
    vertex_provider = GoogleVertexProvider(**vertex_config)

    llm_models = vertex_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{vertex_provider.name}/{first_llm.model}"

    embeddings = vertex_provider.list_embedding_models()
    assert len(embeddings) > 0
    first_embedding = embeddings[0]
    assert first_embedding.handle == f"{vertex_provider.name}/{first_embedding.embedding_model}"
def test_together():
    """Check that TogetherProvider lists LLM models with correctly-formed handles.

    The provider uses the Together API key and the default prompt formatter
    from ``model_settings``. Only LLM models are verified; the first one's
    handle must be ``"<provider name>/<model name>"``.
    """
    together_provider = TogetherProvider(
        name="together",
        api_key=model_settings.together_api_key,
        default_prompt_formatter=model_settings.default_prompt_formatter,
    )

    llm_models = together_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{together_provider.name}/{first_llm.model}"

    # TODO: We don't have embedding models on together for CI
    # embedding_models = together_provider.list_embedding_models()
    # assert len(embedding_models) > 0
    # assert embedding_models[0].handle == f"{together_provider.name}/{embedding_models[0].embedding_model}"
@pytest.mark.asyncio
async def test_together_async():
    """Async counterpart of ``test_together`` using ``list_llm_models_async``.

    The provider uses the Together API key and the default prompt formatter
    from ``model_settings``. The awaited LLM listing must be non-empty, and the
    first model's handle must be ``"<provider name>/<model name>"``.
    """
    provider = TogetherProvider(
        name="together",
        api_key=model_settings.together_api_key,
        default_prompt_formatter=model_settings.default_prompt_formatter,
    )
    models = await provider.list_llm_models_async()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    # TODO: We don't have embedding models on together for CI
    # When re-enabling, use the awaited async listing (as the OpenAI / Google AI
    # async tests do) rather than the sync list_embedding_models():
    # embedding_models = await provider.list_embedding_models_async()
    # assert len(embedding_models) > 0
    # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
# TODO: Add back in, difficulty adding this to CI properly, need boto credentials
# def test_anthropic_bedrock():
# from letta.settings import model_settings
#
# provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
# models = provider.list_llm_models()
# assert len(models) > 0
# assert models[0].handle == f"{provider.name}/{models[0].model}"
#
# embedding_models = provider.list_embedding_models()
# assert len(embedding_models) > 0
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_custom_anthropic():
    """Check that a custom provider name is reflected in Anthropic model handles.

    An AnthropicProvider is created under the name ``"custom_anthropic"``. The
    first listed LLM model's handle must use that name as its prefix
    (``"<provider name>/<model name>"``).
    """
    custom_provider = AnthropicProvider(
        name="custom_anthropic",
        api_key=model_settings.anthropic_api_key,
    )

    llm_models = custom_provider.list_llm_models()
    assert len(llm_models) > 0
    first_llm = llm_models[0]
    assert first_llm.handle == f"{custom_provider.name}/{first_llm.model}"
# def test_vllm():
# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
# models = provider.list_llm_models()
# print(models)
#
# provider.list_embedding_models()