feat: Add more models for send_message tests (#1847)
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"model_endpoint_type": "anthropic",
|
||||
"model_endpoint": "https://api.anthropic.com/v1",
|
||||
"model_wrapper": null,
|
||||
"context_window": 200000,
|
||||
"put_inner_thoughts_in_kwargs": false,
|
||||
"enable_reasoner": true,
|
||||
"max_reasoning_tokens": 1024
|
||||
}
|
||||
8
tests/configs/llm_model_configs/claude-3-7-sonnet.json
Normal file
8
tests/configs/llm_model_configs/claude-3-7-sonnet.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"model_endpoint_type": "anthropic",
|
||||
"model_endpoint": "https://api.anthropic.com/v1",
|
||||
"model_wrapper": null,
|
||||
"context_window": 200000,
|
||||
"put_inner_thoughts_in_kwargs": true
|
||||
}
|
||||
7
tests/configs/llm_model_configs/openai-gpt-4o-mini.json
Normal file
7
tests/configs/llm_model_configs/openai-gpt-4o-mini.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"context_window": 8192,
|
||||
"model": "gpt-4o-mini",
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://api.openai.com/v1",
|
||||
"model_wrapper": null
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -10,6 +11,7 @@ from letta_client import AsyncLetta, Letta, Run
|
||||
from letta_client.types import AssistantMessage, ReasoningMessage
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
@@ -91,8 +93,27 @@ def agent_state(client: Letta) -> AgentState:
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
|
||||
filename = os.path.join(llm_config_dir, filename)
|
||||
config_data = json.load(open(filename, "r"))
|
||||
llm_config = LLMConfig(**config_data)
|
||||
return llm_config
|
||||
|
||||
|
||||
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}]
|
||||
TESTED_MODELS: List[str] = ["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20241022"]
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
"azure-gpt-4o-mini.json",
|
||||
"claude-3-5-sonnet.json",
|
||||
"claude-3-7-sonnet.json",
|
||||
"claude-3-7-sonnet-extended.json",
|
||||
"gemini-pro.json",
|
||||
"gemini-vertex.json",
|
||||
]
|
||||
requested = os.getenv("LLM_CONFIG_FILE")
|
||||
filenames = [requested] if requested else all_configs
|
||||
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
||||
|
||||
|
||||
def assert_tool_response_messages(messages: List[Any]) -> None:
|
||||
@@ -171,18 +192,18 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
def test_send_message_sync_client(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE,
|
||||
@@ -191,18 +212,18 @@ def test_send_message_sync_client(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
async def test_send_message_async_client(
|
||||
disable_e2b_api_key: Any,
|
||||
async_client: AsyncLetta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with an asynchronous client.
|
||||
Validates that the response messages match the expected sequence.
|
||||
"""
|
||||
await async_client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = await async_client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE,
|
||||
@@ -210,18 +231,18 @@ async def test_send_message_async_client(
|
||||
assert_tool_response_messages(response.messages)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
def test_send_message_streaming_sync_client(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE,
|
||||
@@ -231,18 +252,18 @@ def test_send_message_streaming_sync_client(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
async def test_send_message_streaming_async_client(
|
||||
disable_e2b_api_key: Any,
|
||||
async_client: AsyncLetta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with an asynchronous client.
|
||||
Validates that the streaming response chunks include the correct message types.
|
||||
"""
|
||||
await async_client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = async_client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE,
|
||||
@@ -251,18 +272,18 @@ async def test_send_message_streaming_async_client(
|
||||
assert_streaming_tool_response_messages(chunks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
def test_send_message_job_sync_client(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message as an asynchronous job using the synchronous client.
|
||||
Waits for job completion and asserts that the result messages are as expected.
|
||||
"""
|
||||
client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
run = client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
@@ -278,19 +299,19 @@ def test_send_message_job_sync_client(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
async def test_send_message_job_async_client(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
async_client: AsyncLetta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message as an asynchronous job using the asynchronous client.
|
||||
Waits for job completion and verifies that the resulting messages meet the expected format.
|
||||
"""
|
||||
await async_client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
run = await async_client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
|
||||
Reference in New Issue
Block a user