test: add new agent fixture to send message test (#2758)

cthomas
2025-06-11 11:48:34 -07:00
committed by GitHub
parent c1255dc9d1
commit 7c7e2d62d7


@@ -28,88 +28,6 @@ from letta_client.types import (
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
yield client_instance
@pytest.fixture(scope="function")
def async_client(server_url: str) -> AsyncLetta:
"""
Creates and returns an asynchronous Letta REST client for testing.
"""
async_client_instance = AsyncLetta(base_url=server_url)
yield async_client_instance
@pytest.fixture(scope="module")
def agent_state(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
"""
client.tools.upsert_base_tools()
send_message_tool = client.tools.list(name="send_message")[0]
agent_state_instance = client.agents.create(
name="supervisor",
include_base_tools=False,
tool_ids=[send_message_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["supervisor"],
)
yield agent_state_instance
client.agents.delete(agent_state_instance.id)
# ------------------------------
# Helper Functions and Constants
# ------------------------------
@@ -175,7 +93,7 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
]
all_configs = [
"openai-gpt-4o-mini.json",
# "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop
"azure-gpt-4o-mini.json",
"claude-3-5-sonnet.json",
"claude-3-7-sonnet.json",
"claude-3-7-sonnet-extended.json",
@@ -377,19 +295,6 @@ def accumulate_chunks(chunks: List[Any]) -> List[Any]:
return [m for m in messages if m is not None]
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
start = time.time()
while True:
run = client.runs.retrieve(run_id)
if run.status == "completed":
return run
if run.status == "failed":
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
if time.time() - start > timeout:
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
time.sleep(interval)
def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
"""
Asserts that a list of message dictionaries contains the expected types and statuses.
@@ -406,6 +311,108 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
assert messages[1]["message_type"] == "assistant_message"
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
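# Use an externally provided server when LETTA_SERVER_URL is set; otherwise start one in a background thread and wait for it to come up.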
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
yield client_instance
@pytest.fixture(scope="function")
def async_client(server_url: str) -> AsyncLetta:
"""
Creates and returns an asynchronous Letta REST client for testing.
"""
async_client_instance = AsyncLetta(base_url=server_url)
yield async_client_instance
@pytest.fixture(scope="module")
def agent_state(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
The agent is named 'supervisor' and is configured with the send_message and roll_dice tools.
"""
client.tools.upsert_base_tools()
dice_tool = client.tools.upsert_from_function(func=roll_dice)
send_message_tool = client.tools.list(name="send_message")[0]
agent_state_instance = client.agents.create(
name="supervisor",
include_base_tools=False,
tool_ids=[send_message_tool.id, dice_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["supervisor"],
)
yield agent_state_instance
client.agents.delete(agent_state_instance.id)
@pytest.fixture(scope="module")
def agent_state_no_tools(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
The agent is named 'supervisor' and is configured with no tools.
"""
agent_state_instance = client.agents.create(
name="supervisor",
include_base_tools=False,
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["supervisor"],
)
yield agent_state_instance
client.agents.delete(agent_state_instance.id)
# ------------------------------
# Test Cases
# ------------------------------
@@ -479,8 +486,6 @@ def test_tool_call(
Tests sending a message with a synchronous client.
Verifies that the response messages follow the expected order.
"""
dice_tool = client.tools.upsert_from_function(func=roll_dice)
client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create(
@@ -552,24 +557,22 @@ def test_base64_image_input(
def test_agent_loop_error(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
agent_state_no_tools: AgentState,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message with a synchronous client.
Verifies that no new messages are persisted on error.
"""
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
tools = agent_state.tools
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config)
with pytest.raises(ApiError):
client.agents.messages.create(
agent_id=agent_state.id,
agent_id=agent_state_no_tools.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
assert len(messages_from_db) == 0
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
@pytest.mark.parametrize(
@@ -593,8 +596,7 @@ def test_step_streaming_greeting_with_assistant_message(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
messages = accumulate_chunks(list(response))
assert_greeting_with_assistant_message_response(messages, streaming=True)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
@@ -622,8 +624,7 @@ def test_step_streaming_greeting_without_assistant_message(
messages=USER_MESSAGE_FORCE_REPLY,
use_assistant_message=False,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
messages = accumulate_chunks(list(response))
assert_greeting_without_assistant_message_response(messages, streaming=True)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
@@ -644,16 +645,13 @@ def test_step_streaming_tool_call(
Tests sending a streaming message with a synchronous client.
Checks that each chunk in the stream has the correct message types.
"""
dice_tool = client.tools.upsert_from_function(func=roll_dice)
agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
messages = accumulate_chunks(list(response))
assert_tool_call_response(messages, streaming=True)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_tool_call_response(messages_from_db, from_db=True)
@@ -667,26 +665,24 @@ def test_step_streaming_tool_call(
def test_step_stream_agent_loop_error(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
agent_state_no_tools: AgentState,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message with a synchronous client.
Verifies that no new messages are persisted on error.
"""
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
tools = agent_state.tools
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config)
with pytest.raises(ApiError):
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
agent_id=agent_state_no_tools.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
list(response)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
assert len(messages_from_db) == 0
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
@pytest.mark.parametrize(
@@ -711,8 +707,7 @@ def test_token_streaming_greeting_with_assistant_message(
messages=USER_MESSAGE_FORCE_REPLY,
stream_tokens=True,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
messages = accumulate_chunks(list(response))
assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
@@ -741,8 +736,7 @@ def test_token_streaming_greeting_without_assistant_message(
use_assistant_message=False,
stream_tokens=True,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
messages = accumulate_chunks(list(response))
assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
@@ -763,8 +757,6 @@ def test_token_streaming_tool_call(
Tests sending a streaming message with a synchronous client.
Checks that each chunk in the stream has the correct message types.
"""
dice_tool = client.tools.upsert_from_function(func=roll_dice)
agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
@@ -772,8 +764,7 @@ def test_token_streaming_tool_call(
messages=USER_MESSAGE_ROLL_DICE,
stream_tokens=True,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
messages = accumulate_chunks(list(response))
assert_tool_call_response(messages, streaming=True)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_tool_call_response(messages_from_db, from_db=True)
@@ -787,19 +778,18 @@ def test_token_streaming_tool_call(
def test_token_streaming_agent_loop_error(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
agent_state_no_tools: AgentState,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message with a synchronous client.
Verifies that no new messages are persisted on error.
"""
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
tools = agent_state.tools
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config, tool_ids=[])
try:
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
agent_id=agent_state_no_tools.id,
messages=USER_MESSAGE_FORCE_REPLY,
stream_tokens=True,
)
@@ -807,9 +797,21 @@ def test_token_streaming_agent_loop_error(
except:
pass  # only some models raise an error here; TODO: make this consistent
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
assert len(messages_from_db) == 0
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
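"""Poll the run until it completes, raising if it fails or times out."""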
start = time.time()
while True:
run = client.runs.retrieve(run_id)
if run.status == "completed":
return run
if run.status == "failed":
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
if time.time() - start > timeout:
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
time.sleep(interval)
@pytest.mark.parametrize(
@@ -850,7 +852,6 @@ def test_async_greeting_with_assistant_message(
def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLMConfig):
"""Test that summarization is automatically triggered."""
llm_config.context_window = 3000
client.tools.upsert_base_tools()
send_message_tool = client.tools.list(name="send_message")[0]
temp_agent_state = client.agents.create(