chore: migrate built in tools integration test to sdk v1 [LET-5980] (#5883)

* chore: migrate built in tools integration test to sdk v1

* fix

* remove trialing commas
This commit is contained in:
Christina Tong
2025-10-31 14:17:52 -07:00
committed by Caren Thomas
parent 255fdfecf2
commit 381ca5bde8
3 changed files with 344 additions and 38 deletions

View File

@@ -0,0 +1,313 @@
import json
import os
import threading
import time
import uuid
from unittest.mock import MagicMock, patch
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until its accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
yield url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
yield client_instance
@pytest.fixture(scope="function")
def agent_state(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
Uses system-level EXA_API_KEY setting.
"""
client.tools.upsert_base_tools()
send_message_tool = client.tools.list(name="send_message")[0]
run_code_tool = client.tools.list(name="run_code")[0]
web_search_tool = client.tools.list(name="web_search")[0]
agent_state_instance = client.agents.create(
name="test_builtin_tools_agent",
include_base_tools=False,
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["test_builtin_tools_agent"],
)
yield agent_state_instance
# ------------------------------
# Helper Functions and Constants
# ------------------------------
USER_MESSAGE_OTID = str(uuid.uuid4())
TEST_LANGUAGES = ["Python", "Javascript", "Typescript"]
EXPECTED_INTEGER_PARTITION_OUTPUT = "190569292"
# Reference implementation in Python, to embed in the user prompt
REFERENCE_CODE = """\
def reference_partition(n):
partitions = [1] + [0] * (n + 1)
for k in range(1, n + 1):
for i in range(k, n + 1):
partitions[i] += partitions[i - k]
return partitions[n]
"""
def reference_partition(n: int) -> int:
# Same logic, used to compute expected result in the test
partitions = [1] + [0] * (n + 1)
for k in range(1, n + 1):
for i in range(k, n + 1):
partitions[i] += partitions[i - k]
return partitions[n]
# ------------------------------
# Test Cases
# ------------------------------
@pytest.mark.parametrize("language", TEST_LANGUAGES, ids=TEST_LANGUAGES)
def test_run_code(
client: Letta,
agent_state: AgentState,
language: str,
) -> None:
"""
Sends a reference Python implementation, asks the model to translate & run it
in different languages, and verifies the exact partition(100) result.
"""
expected = str(reference_partition(100))
user_message = MessageCreateParam(
role="user",
content=(
"Here is a Python reference implementation:\n\n"
f"{REFERENCE_CODE}\n"
f"Please translate and execute this code in {language} to compute p(100), "
"and return **only** the result with no extra formatting."
),
otid=USER_MESSAGE_OTID,
)
response = client.agents.messages.send(
agent_id=agent_state.id,
messages=[user_message],
)
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
assert tool_returns, f"No ToolReturnMessage found for language: {language}"
returns = [m.tool_return for m in tool_returns]
assert any(expected in ret for ret in returns), (
f"For language={language!r}, expected to find '{expected}' in tool_return, but got {returns!r}"
)
@pytest.mark.asyncio(scope="function")
async def test_web_search() -> None:
"""Test web search tool with mocked Exa API."""
# create mock agent state with exa api key
mock_agent_state = MagicMock()
mock_agent_state.get_agent_env_vars_as_dict.return_value = {"EXA_API_KEY": "test-exa-key"}
# Mock Exa search result with education information
mock_exa_result = MagicMock()
mock_exa_result.results = [
MagicMock(
title="Charles Packer - UC Berkeley PhD in Computer Science",
url="https://example.com/charles-packer-profile",
published_date="2023-01-01",
author="UC Berkeley",
text=None,
highlights=["Charles Packer completed his PhD at UC Berkeley", "Research in artificial intelligence and machine learning"],
summary="Charles Packer is the CEO of Letta who earned his PhD in Computer Science from UC Berkeley, specializing in AI research.",
),
MagicMock(
title="Letta Leadership Team",
url="https://letta.com/team",
published_date="2023-06-01",
author="Letta",
text=None,
highlights=["CEO Charles Packer brings academic expertise"],
summary="Leadership team page featuring CEO Charles Packer's educational background.",
),
]
with patch("exa_py.Exa") as mock_exa_class:
# Setup mock
mock_exa_client = MagicMock()
mock_exa_class.return_value = mock_exa_client
mock_exa_client.search_and_contents.return_value = mock_exa_result
# create executor with mock dependencies
executor = LettaBuiltinToolExecutor(
message_manager=MagicMock(),
agent_manager=MagicMock(),
block_manager=MagicMock(),
run_manager=MagicMock(),
passage_manager=MagicMock(),
actor=MagicMock(),
)
# call web_search directly
result = await executor.web_search(
agent_state=mock_agent_state,
query="where did Charles Packer, CEO of Letta, go to school",
num_results=10,
include_text=False,
)
# Parse the JSON response from web_search
response_json = json.loads(result)
# Basic structure assertions for new Exa format
assert "query" in response_json, "Missing 'query' field in response"
assert "results" in response_json, "Missing 'results' field in response"
# Verify we got search results
results = response_json["results"]
assert len(results) == 2, "Should have found exactly 2 search results from mock"
# Check each result has the expected structure
found_education_info = False
for result in results:
assert "title" in result, "Result missing title"
assert "url" in result, "Result missing URL"
# text should not be present since include_text=False by default
assert "text" not in result or result["text"] is None, "Text should not be included by default"
# Check for education-related information in summary and highlights
result_text = ""
if "summary" in result and result["summary"]:
result_text += " " + result["summary"].lower()
if "highlights" in result and result["highlights"]:
for highlight in result["highlights"]:
result_text += " " + highlight.lower()
# Look for education keywords
if any(keyword in result_text for keyword in ["berkeley", "university", "phd", "ph.d", "education", "student"]):
found_education_info = True
assert found_education_info, "Should have found education-related information about Charles Packer"
# Verify Exa was called with correct parameters
mock_exa_class.assert_called_once_with(api_key="test-exa-key")
mock_exa_client.search_and_contents.assert_called_once()
call_args = mock_exa_client.search_and_contents.call_args
assert call_args[1]["type"] == "auto"
assert call_args[1]["text"] is False # Default is False now
@pytest.mark.asyncio(scope="function")
async def test_web_search_uses_exa():
"""Test that web search uses Exa API correctly."""
# create mock agent state with exa api key
mock_agent_state = MagicMock()
mock_agent_state.get_agent_env_vars_as_dict.return_value = {"EXA_API_KEY": "test-exa-key"}
# Mock exa search result
mock_exa_result = MagicMock()
mock_exa_result.results = [
MagicMock(
title="Test Result",
url="https://example.com/test",
published_date="2023-01-01",
author="Test Author",
text="This is test content from the search result.",
highlights=["This is a highlight"],
summary="This is a summary of the content.",
)
]
with patch("exa_py.Exa") as mock_exa_class:
# Mock Exa
mock_exa_client = MagicMock()
mock_exa_class.return_value = mock_exa_client
mock_exa_client.search_and_contents.return_value = mock_exa_result
# create executor with mock dependencies
executor = LettaBuiltinToolExecutor(
message_manager=MagicMock(),
agent_manager=MagicMock(),
block_manager=MagicMock(),
run_manager=MagicMock(),
passage_manager=MagicMock(),
actor=MagicMock(),
)
result = await executor.web_search(agent_state=mock_agent_state, query="test query", num_results=3, include_text=True)
# Verify Exa was called correctly
mock_exa_class.assert_called_once_with(api_key="test-exa-key")
mock_exa_client.search_and_contents.assert_called_once()
# Check the call arguments
call_args = mock_exa_client.search_and_contents.call_args
assert call_args[1]["query"] == "test query"
assert call_args[1]["num_results"] == 3
assert call_args[1]["type"] == "auto"
assert call_args[1]["text"] == True
# Verify the response format
response_json = json.loads(result)
assert "query" in response_json
assert "results" in response_json
assert response_json["query"] == "test query"
assert len(response_json["results"]) == 1

View File

@@ -1,5 +1,6 @@
import base64
import json
import logging
import os
import threading
import time
@@ -14,7 +15,7 @@ import pytest
import requests
from dotenv import load_dotenv
from letta_client import APIError, AsyncLetta, Letta
from letta_client.types import ToolReturnMessage
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta_client.types.agents import (
AssistantMessage,
HiddenReasoningMessage,
@@ -24,21 +25,16 @@ from letta_client.types.agents import (
ToolCallMessage,
UserMessage,
)
from letta_client.types.agents.image_content_param import ImageContentParam, SourceBase64Image, SourceURLImage
from letta_client.types.agents.image_content_param import ImageContentParam, SourceBase64Image
from letta_client.types.agents.letta_streaming_response import LettaPing, LettaStopReason, LettaUsageStatistics
from letta_client.types.agents.text_content_param import TextContentParam
from letta.errors import LLMError
from letta.helpers.reasoning_helper import is_reasoning_completely_disabled
from letta.llm_api.openai_client import is_openai_reasoning_model
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.letta_message_content import Base64Image, ImageContent, TextContent, UrlImage
from letta.schemas.letta_request import LettaRequest
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
logger = get_logger(__name__)
logger = logging.getLogger(__name__)
# ------------------------------
# Helper Functions and Constants
@@ -68,8 +64,8 @@ def roll_dice(num_sides: int) -> int:
USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_FORCE_REPLY: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=f"This is an automated test message. Call the send_message tool with the message '{USER_MESSAGE_RESPONSE}'.",
otid=USER_MESSAGE_OTID,
@@ -85,29 +81,29 @@ USER_MESSAGE_LONG_RESPONSE: str = (
"Successful teams celebrate victories together and learn from failures as a unit, creating a culture of continuous improvement. "
"Together, we can overcome challenges that would be insurmountable alone, achieving extraordinary results through the power of collaboration."
)
USER_MESSAGE_FORCE_LONG_REPLY: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_FORCE_LONG_REPLY: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=f"This is an automated test message. Call the send_message tool with exactly this message: '{USER_MESSAGE_LONG_RESPONSE}'",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_GREETING: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_GREETING: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content="Hi!",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content="This is an automated test message. Call the roll_dice tool with 16 sides and send me a message with the outcome.",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE_LONG: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_ROLL_DICE_LONG: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=(
"This is an automated test message. Call the roll_dice tool with 16 sides and send me a very detailed, comprehensive message about the outcome. "
@@ -123,8 +119,8 @@ USER_MESSAGE_ROLL_DICE_LONG: List[MessageCreate] = [
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE_GEMINI_FLASH: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_ROLL_DICE_GEMINI_FLASH: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=(
'This is an automated test message. First, call the roll_dice tool with exactly this JSON: {"num_sides": 16, "request_heartbeat": true}. '
@@ -134,8 +130,8 @@ USER_MESSAGE_ROLL_DICE_GEMINI_FLASH: List[MessageCreate] = [
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE_LONG_THINKING: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_ROLL_DICE_LONG_THINKING: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=(
"This is an automated test message. First, think long and hard about about why you're here, and your creator. "
@@ -158,8 +154,8 @@ USER_MESSAGE_ROLL_DICE_LONG_THINKING: List[MessageCreate] = [
BASE64_IMAGE = base64.standard_b64encode(
httpx.get("https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg").content
).decode("utf-8")
USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_BASE64_IMAGE: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=[
ImageContentParam(type="image", source=SourceBase64Image(type="base64", data=BASE64_IMAGE, media_type="image/jpeg")),
@@ -1951,7 +1947,7 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM
try:
client.agents.messages.send(
agent_id=temp_agent_state.id,
messages=[MessageCreate(role="user", content=philosophical_question)],
messages=[MessageCreateParam(role="user", content=philosophical_question)],
)
except Exception as e:
# if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):

View File

@@ -1,6 +1,7 @@
import asyncio
import itertools
import json
import logging
import os
import threading
import time
@@ -11,17 +12,13 @@ import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta
from letta_client.types import ToolReturnMessage
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta_client.types.agents import AssistantMessage, ReasoningMessage, Run, ToolCallMessage, UserMessage
from letta_client.types.agents.letta_streaming_response import LettaPing, LettaStopReason, LettaUsageStatistics
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
logger = get_logger(__name__)
logger = logging.getLogger(__name__)
# ------------------------------
@@ -77,22 +74,22 @@ def roll_dice(num_sides: int) -> int:
USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_FORCE_REPLY: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=f"This is an automated test message. Reply with the message '{USER_MESSAGE_RESPONSE}'.",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=("This is an automated test message. Please call the roll_dice tool three times in parallel."),
otid=USER_MESSAGE_OTID,
@@ -501,7 +498,7 @@ async def agent_state(client: AsyncLetta) -> AgentState:
dice_tool = await client.tools.upsert_from_function(func=roll_dice)
agent_state_instance = await client.agents.create(
agent_type=AgentType.letta_v1_agent,
agent_type="letta_v1_agent",
name="test_agent",
include_base_tools=False,
tool_ids=[dice_tool.id],