feat: Add more models for send_message tests (#1847)

This commit is contained in:
Matthew Zhou
2025-04-22 17:03:21 -07:00
committed by GitHub
parent 68a9e31eb1
commit c1644163be
4 changed files with 65 additions and 19 deletions

View File

@@ -0,0 +1,10 @@
{
"model": "claude-3-7-sonnet-20250219",
"model_endpoint_type": "anthropic",
"model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null,
"context_window": 200000,
"put_inner_thoughts_in_kwargs": false,
"enable_reasoner": true,
"max_reasoning_tokens": 1024
}

View File

@@ -0,0 +1,8 @@
{
"model": "claude-3-7-sonnet-20250219",
"model_endpoint_type": "anthropic",
"model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null,
"context_window": 200000,
"put_inner_thoughts_in_kwargs": true
}

View File

@@ -0,0 +1,7 @@
{
"context_window": 8192,
"model": "gpt-4o-mini",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": null
}

View File

@@ -1,3 +1,4 @@
import json
import os
import threading
import time
@@ -10,6 +11,7 @@ from letta_client import AsyncLetta, Letta, Run
from letta_client.types import AssistantMessage, ReasoningMessage
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
# ------------------------------
# Fixtures
@@ -91,8 +93,27 @@ def agent_state(client: Letta) -> AgentState:
# Helper Functions and Constants
# ------------------------------
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
filename = os.path.join(llm_config_dir, filename)
config_data = json.load(open(filename, "r"))
llm_config = LLMConfig(**config_data)
return llm_config
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}]
TESTED_MODELS: List[str] = ["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20241022"]
all_configs = [
"openai-gpt-4o-mini.json",
"azure-gpt-4o-mini.json",
"claude-3-5-sonnet.json",
"claude-3-7-sonnet.json",
"claude-3-7-sonnet-extended.json",
"gemini-pro.json",
"gemini-vertex.json",
]
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
def assert_tool_response_messages(messages: List[Any]) -> None:
@@ -171,18 +192,18 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
# ------------------------------
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
def test_send_message_sync_client(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message with a synchronous client.
Verifies that the response messages follow the expected order.
"""
client.agents.modify(agent_id=agent_state.id, model=model)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE,
@@ -191,18 +212,18 @@ def test_send_message_sync_client(
@pytest.mark.asyncio
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
async def test_send_message_async_client(
disable_e2b_api_key: Any,
async_client: AsyncLetta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message with an asynchronous client.
Validates that the response messages match the expected sequence.
"""
await async_client.agents.modify(agent_id=agent_state.id, model=model)
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = await async_client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE,
@@ -210,18 +231,18 @@ async def test_send_message_async_client(
assert_tool_response_messages(response.messages)
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
def test_send_message_streaming_sync_client(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a streaming message with a synchronous client.
Checks that each chunk in the stream has the correct message types.
"""
client.agents.modify(agent_id=agent_state.id, model=model)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE,
@@ -231,18 +252,18 @@ def test_send_message_streaming_sync_client(
@pytest.mark.asyncio
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
async def test_send_message_streaming_async_client(
disable_e2b_api_key: Any,
async_client: AsyncLetta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a streaming message with an asynchronous client.
Validates that the streaming response chunks include the correct message types.
"""
await async_client.agents.modify(agent_id=agent_state.id, model=model)
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = async_client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE,
@@ -251,18 +272,18 @@ async def test_send_message_streaming_async_client(
assert_streaming_tool_response_messages(chunks)
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
def test_send_message_job_sync_client(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message as an asynchronous job using the synchronous client.
Waits for job completion and asserts that the result messages are as expected.
"""
client.agents.modify(agent_id=agent_state.id, model=model)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
run = client.agents.messages.create_async(
agent_id=agent_state.id,
@@ -278,19 +299,19 @@ def test_send_message_job_sync_client(
@pytest.mark.asyncio
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
async def test_send_message_job_async_client(
disable_e2b_api_key: Any,
client: Letta,
async_client: AsyncLetta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message as an asynchronous job using the asynchronous client.
Waits for job completion and verifies that the resulting messages meet the expected format.
"""
await async_client.agents.modify(agent_id=agent_state.id, model=model)
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
run = await async_client.agents.messages.create_async(
agent_id=agent_state.id,