feat: cutover repo to 1.0 sdk client LET-6256 (#6361)

feat: cutover repo to 1.0 sdk client
This commit is contained in:
cthomas
2025-11-24 18:39:26 -08:00
committed by Caren Thomas
parent 98edb3fe86
commit 7b0bd1cb13
54 changed files with 2385 additions and 10257 deletions

View File

@@ -1,6 +1,6 @@
[project]
name = "letta"
version = "0.14.0"
version = "0.14.1"
description = "Create LLM agents with long-term memory and custom tools"
authors = [
{name = "Letta Team", email = "contact@letta.com"},
@@ -43,7 +43,7 @@ dependencies = [
"llama-index>=0.12.2",
"llama-index-embeddings-openai>=0.3.1",
"anthropic>=0.75.0",
"letta-client>=0.1.319",
"letta-client>=1.1.2",
"openai>=1.99.9",
"opentelemetry-api==1.30.0",
"opentelemetry-sdk==1.30.0",

View File

@@ -1,10 +1,15 @@
import logging
import os
import threading
import time
from datetime import datetime, timezone
from typing import Generator
import pytest
import requests
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
from dotenv import load_dotenv
from letta_client import Letta
from letta.server.db import db_registry
from letta.services.organization_manager import OrganizationManager
@@ -16,6 +21,52 @@ def pytest_configure(config):
logging.basicConfig(level=logging.DEBUG)
@pytest.fixture(scope="session")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
@pytest.fixture(scope="session")
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
yield client_instance
@pytest.fixture(scope="session", autouse=True)
def disable_db_pooling_for_tests():
"""Disable database connection pooling for the entire test session."""

View File

@@ -3,19 +3,15 @@ import os
import threading
import time
import uuid
from typing import List
from unittest.mock import MagicMock, patch
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate
from letta_client.types import ToolReturnMessage
from letta_client import Letta
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
from letta.settings import tool_settings
# ------------------------------
# Fixtures
@@ -76,9 +72,9 @@ def agent_state(client: Letta) -> AgentState:
"""
client.tools.upsert_base_tools()
send_message_tool = client.tools.list(name="send_message")[0]
run_code_tool = client.tools.list(name="run_code")[0]
web_search_tool = client.tools.list(name="web_search")[0]
send_message_tool = list(client.tools.list(name="send_message"))[0]
run_code_tool = list(client.tools.list(name="run_code"))[0]
web_search_tool = list(client.tools.list(name="web_search"))[0]
agent_state_instance = client.agents.create(
name="test_builtin_tools_agent",
include_base_tools=False,
@@ -94,23 +90,7 @@ def agent_state(client: Letta) -> AgentState:
# Helper Functions and Constants
# ------------------------------
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
filename = os.path.join(llm_config_dir, filename)
with open(filename, "r") as f:
config_data = json.load(f)
llm_config = LLMConfig(**config_data)
return llm_config
USER_MESSAGE_OTID = str(uuid.uuid4())
all_configs = [
"openai-gpt-4o-mini.json",
]
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
TEST_LANGUAGES = ["Python", "Javascript", "Typescript"]
EXPECTED_INTEGER_PARTITION_OUTPUT = "190569292"
@@ -152,7 +132,7 @@ def test_run_code(
"""
expected = str(reference_partition(100))
user_message = MessageCreate(
user_message = MessageCreateParam(
role="user",
content=(
"Here is a Python reference implementation:\n\n"

View File

@@ -1,22 +1,16 @@
import os
import threading
import time
import logging
import uuid
from typing import List
from typing import Any, List
from unittest.mock import patch
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AgentState, ApprovalCreate, Letta, LlmConfig, MessageCreate, Tool
from letta_client.core.api_error import ApiError
from letta_client import APIError, Letta
from letta_client.types import AgentState, MessageCreateParam, Tool
from letta_client.types.agents import ApprovalCreateParam
from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.log import get_logger
from letta.schemas.enums import AgentType
logger = get_logger(__name__)
logger = logging.getLogger(__name__)
# ------------------------------
# Helper Functions and Constants
@@ -24,8 +18,8 @@ logger = get_logger(__name__)
USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'."
USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_TEST_APPROVAL: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=USER_MESSAGE_CONTENT,
otid=USER_MESSAGE_OTID,
@@ -35,16 +29,16 @@ FAKE_REQUEST_ID = str(uuid.uuid4())
SECRET_CODE = str(740845635798344975)
USER_MESSAGE_FOLLOW_UP_OTID = str(uuid.uuid4())
USER_MESSAGE_FOLLOW_UP_CONTENT = "Thank you for the secret code."
USER_MESSAGE_FOLLOW_UP: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_FOLLOW_UP: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=USER_MESSAGE_FOLLOW_UP_CONTENT,
otid=USER_MESSAGE_FOLLOW_UP_OTID,
)
]
USER_MESSAGE_PARALLEL_TOOL_CALL_CONTENT = "This is an automated test message. Call the get_secret_code_tool 3 times in parallel for the following inputs: 'hello world', 'hello letta', 'hello test', and also call the roll_dice_tool once with a 16-sided dice."
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=USER_MESSAGE_PARALLEL_TOOL_CALL_CONTENT,
otid=USER_MESSAGE_OTID,
@@ -78,20 +72,41 @@ def roll_dice_tool(num_sides: int) -> str:
def accumulate_chunks(stream):
messages = []
current_message = None
prev_message_type = None
for chunk in stream:
current_message_type = chunk.message_type
# Handle chunks that might not have message_type (like pings)
if not hasattr(chunk, "message_type"):
continue
current_message_type = getattr(chunk, "message_type", None)
if prev_message_type != current_message_type:
messages.append(chunk)
# Save the previous message if it exists
if current_message is not None:
messages.append(current_message)
# Start a new message
current_message = chunk
else:
# Accumulate content for same message type (token streaming)
if current_message is not None and hasattr(current_message, "content") and hasattr(chunk, "content"):
current_message.content += chunk.content
prev_message_type = current_message_type
return messages
# Don't forget the last message
if current_message is not None:
messages.append(current_message)
return [m for m in messages if m is not None]
def approve_tool_call(client: Letta, agent_id: str, tool_call_id: str):
client.agents.messages.create(
agent_id=agent_id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -109,56 +124,11 @@ def approve_tool_call(client: Letta, agent_id: str, tool_call_id: str):
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
yield client_instance
# Note: server_url and client fixtures are inherited from tests/conftest.py
@pytest.fixture(scope="function")
def approval_tool_fixture(client: Letta) -> Tool:
def approval_tool_fixture(client: Letta):
"""
Creates and returns a tool that requires approval for testing.
"""
@@ -173,7 +143,7 @@ def approval_tool_fixture(client: Letta) -> Tool:
@pytest.fixture(scope="function")
def dice_tool_fixture(client: Letta) -> Tool:
def dice_tool_fixture(client: Letta):
client.tools.upsert_base_tools()
dice_tool = client.tools.upsert_from_function(
func=roll_dice_tool,
@@ -191,19 +161,17 @@ def agent(client: Letta, approval_tool_fixture, dice_tool_fixture) -> AgentState
"""
agent_state = client.agents.create(
name="approval_test_agent",
agent_type=AgentType.letta_v1_agent,
agent_type="letta_v1_agent",
include_base_tools=False,
tool_ids=[approval_tool_fixture.id, dice_tool_fixture.id],
include_base_tool_rules=False,
tool_rules=[],
# parallel_tool_calls=True,
model="anthropic/claude-sonnet-4-5-20250929",
embedding="openai/text-embedding-3-small",
tags=["approval_test"],
)
agent_state = client.agents.modify(
agent_id=agent_state.id, llm_config=dict(agent_state.llm_config.model_dump(), **{"parallel_tool_calls": True})
)
# Enable parallel tool calls for testing
agent_state = client.agents.update(agent_id=agent_state.id, parallel_tool_calls=True)
yield agent_state
client.agents.delete(agent_id=agent_state.id)
@@ -215,11 +183,11 @@ def agent(client: Letta, approval_tool_fixture, dice_tool_fixture) -> AgentState
def test_send_approval_without_pending_request(client, agent):
with pytest.raises(ApiError, match="No tool call is currently awaiting approval"):
with pytest.raises(APIError, match="No tool call is currently awaiting approval"):
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy
approval_request_id=FAKE_REQUEST_ID, # legacy
approvals=[
@@ -240,10 +208,10 @@ def test_send_user_message_with_pending_request(client, agent):
messages=USER_MESSAGE_TEST_APPROVAL,
)
with pytest.raises(ApiError, match="Please approve or deny the pending request before continuing"):
with pytest.raises(APIError, match="Please approve or deny the pending request before continuing"):
client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="hi")],
messages=[MessageCreateParam(role="user", content="hi")],
)
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
@@ -255,11 +223,11 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
messages=USER_MESSAGE_TEST_APPROVAL,
)
with pytest.raises(ApiError, match="Invalid tool call IDs"):
with pytest.raises(APIError, match="Invalid tool call IDs"):
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy
approval_request_id=FAKE_REQUEST_ID, # legacy
approvals=[
@@ -298,7 +266,7 @@ def test_invoke_approval_request(
assert messages[-1].tool_call.name == "get_secret_code_tool"
assert messages[-1].tool_calls is not None
assert len(messages[-1].tool_calls) == 1
assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool"
assert messages[-1].tool_calls[0].name == "get_secret_code_tool"
# v3/v1 path: approval request tool args must not include request_heartbeat
import json as _json
@@ -306,7 +274,7 @@ def test_invoke_approval_request(
_args = _json.loads(messages[-1].tool_call.arguments)
assert "request_heartbeat" not in _args
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
@@ -315,7 +283,7 @@ def test_invoke_approval_request_stream(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
stream_tokens=True,
@@ -330,7 +298,7 @@ def test_invoke_approval_request_stream(
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
approve_tool_call(client, agent.id, messages[-3].tool_call.tool_call_id)
@@ -346,10 +314,10 @@ def test_invoke_tool_after_turning_off_requires_approval(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -365,13 +333,13 @@ def test_invoke_tool_after_turning_off_requires_approval(
)
messages = accumulate_chunks(response)
client.agents.tools.modify_approval(
client.agents.tools.update_approval(
agent_id=agent.id,
tool_name=approval_tool_fixture.name,
requires_approval=False,
body_requires_approval=False,
)
response = client.agents.messages.create_stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
response = client.agents.messages.stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
messages = accumulate_chunks(response)
@@ -420,10 +388,10 @@ def test_approve_tool_call_request(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -452,7 +420,7 @@ def test_approve_cursor_fetch(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
@@ -460,7 +428,7 @@ def test_approve_cursor_fetch(
last_message_id = response.messages[0].id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
assert messages[0].message_type == "user_message"
assert messages[-1].message_type == "approval_request_message"
# Ensure no request_heartbeat on approval request
@@ -472,7 +440,7 @@ def test_approve_cursor_fetch(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -486,12 +454,12 @@ def test_approve_cursor_fetch(
],
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
assert messages[0].message_type == "approval_response_message"
assert messages[0].approval_request_id == tool_call_id
assert messages[0].approve is True
assert messages[0].approvals[0]["approve"] is True
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
assert messages[0].approvals[0].approve is True
assert messages[0].approvals[0].tool_call_id == tool_call_id
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "success"
@@ -506,10 +474,10 @@ def test_approve_with_context_check(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -527,7 +495,7 @@ def test_approve_with_context_check(
messages = accumulate_chunks(response)
try:
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
except Exception as e:
if len(messages) > 4:
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
@@ -547,7 +515,7 @@ def test_approve_and_follow_up(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -561,7 +529,7 @@ def test_approve_and_follow_up(
],
)
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
@@ -587,10 +555,10 @@ def test_approve_and_follow_up_with_error(
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -605,18 +573,11 @@ def test_approve_and_follow_up_with_error(
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
print("\n\nmessages:\n\n")
for m in messages:
print(m)
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "invalid_llm_response"
with pytest.raises(APIError, match="TEST: Mocked error"):
messages = accumulate_chunks(response)
# Ensure that agent is not bricked
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
@@ -648,10 +609,10 @@ def test_deny_tool_call_request(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -680,7 +641,7 @@ def test_deny_cursor_fetch(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
@@ -688,7 +649,7 @@ def test_deny_cursor_fetch(
last_message_id = response.messages[0].id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
assert messages[0].message_type == "user_message"
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call.tool_call_id == tool_call_id
@@ -701,7 +662,7 @@ def test_deny_cursor_fetch(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -717,11 +678,11 @@ def test_deny_cursor_fetch(
],
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
assert messages[0].message_type == "approval_response_message"
assert messages[0].approvals[0]["approve"] == False
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
assert messages[0].approvals[0]["reason"] == f"You don't need to call the tool, the secret code is {SECRET_CODE}"
assert messages[0].approvals[0].approve == False
assert messages[0].approvals[0].tool_call_id == tool_call_id
assert messages[0].approvals[0].reason == f"You don't need to call the tool, the secret code is {SECRET_CODE}"
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "error"
@@ -736,10 +697,10 @@ def test_deny_with_context_check(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy
@@ -759,7 +720,7 @@ def test_deny_with_context_check(
messages = accumulate_chunks(response)
try:
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
except Exception as e:
if len(messages) > 4:
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
@@ -779,7 +740,7 @@ def test_deny_and_follow_up(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -795,7 +756,7 @@ def test_deny_and_follow_up(
],
)
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
@@ -821,10 +782,10 @@ def test_deny_and_follow_up_with_error(
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -841,15 +802,11 @@ def test_deny_and_follow_up_with_error(
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "invalid_llm_response"
with pytest.raises(APIError, match="TEST: Mocked error"):
messages = accumulate_chunks(response)
# Ensure that agent is not bricked
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
@@ -877,10 +834,10 @@ def test_client_side_tool_call_request(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -911,7 +868,7 @@ def test_client_side_tool_call_cursor_fetch(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
@@ -919,7 +876,7 @@ def test_client_side_tool_call_cursor_fetch(
last_message_id = response.messages[0].id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
assert messages[0].message_type == "user_message"
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call.tool_call_id == tool_call_id
@@ -932,7 +889,7 @@ def test_client_side_tool_call_cursor_fetch(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -948,12 +905,12 @@ def test_client_side_tool_call_cursor_fetch(
],
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
assert messages[0].message_type == "approval_response_message"
assert messages[0].approvals[0]["type"] == "tool"
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
assert messages[0].approvals[0]["tool_return"] == SECRET_CODE
assert messages[0].approvals[0]["status"] == "success"
assert messages[0].approvals[0].type == "tool"
assert messages[0].approvals[0].tool_call_id == tool_call_id
assert messages[0].approvals[0].tool_return == SECRET_CODE
assert messages[0].approvals[0].status == "success"
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "success"
assert messages[1].tool_call_id == tool_call_id
@@ -970,10 +927,10 @@ def test_client_side_tool_call_with_context_check(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy
@@ -993,7 +950,7 @@ def test_client_side_tool_call_with_context_check(
messages = accumulate_chunks(response)
try:
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
except Exception as e:
if len(messages) > 4:
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
@@ -1013,7 +970,7 @@ def test_client_side_tool_call_and_follow_up(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -1029,7 +986,7 @@ def test_client_side_tool_call_and_follow_up(
],
)
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
@@ -1055,10 +1012,10 @@ def test_client_side_tool_call_and_follow_up_with_error(
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -1075,15 +1032,11 @@ def test_client_side_tool_call_and_follow_up_with_error(
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "invalid_llm_response"
with pytest.raises(APIError, match="TEST: Mocked error"):
messages = accumulate_chunks(response)
# Ensure that agent is not bricked
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
@@ -1100,7 +1053,7 @@ def test_parallel_tool_calling(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
@@ -1111,32 +1064,32 @@ def test_parallel_tool_calling(
assert messages is not None
assert messages[-2].message_type == "tool_call_message"
assert len(messages[-2].tool_calls) == 1
assert messages[-2].tool_calls[0]["name"] == "roll_dice_tool"
assert "6" in messages[-2].tool_calls[0]["arguments"]
dice_tool_call_id = messages[-2].tool_calls[0]["tool_call_id"]
assert messages[-2].tool_calls[0].name == "roll_dice_tool"
assert "6" in messages[-2].tool_calls[0].arguments
dice_tool_call_id = messages[-2].tool_calls[0].tool_call_id
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call is not None
assert messages[-1].tool_call.name == "get_secret_code_tool"
assert len(messages[-1].tool_calls) == 3
assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool"
assert "hello world" in messages[-1].tool_calls[0]["arguments"]
approve_tool_call_id = messages[-1].tool_calls[0]["tool_call_id"]
assert messages[-1].tool_calls[1]["name"] == "get_secret_code_tool"
assert "hello letta" in messages[-1].tool_calls[1]["arguments"]
deny_tool_call_id = messages[-1].tool_calls[1]["tool_call_id"]
assert messages[-1].tool_calls[2]["name"] == "get_secret_code_tool"
assert "hello test" in messages[-1].tool_calls[2]["arguments"]
client_side_tool_call_id = messages[-1].tool_calls[2]["tool_call_id"]
assert messages[-1].tool_calls[0].name == "get_secret_code_tool"
assert "hello world" in messages[-1].tool_calls[0].arguments
approve_tool_call_id = messages[-1].tool_calls[0].tool_call_id
assert messages[-1].tool_calls[1].name == "get_secret_code_tool"
assert "hello letta" in messages[-1].tool_calls[1].arguments
deny_tool_call_id = messages[-1].tool_calls[1].tool_call_id
assert messages[-1].tool_calls[2].name == "get_secret_code_tool"
assert "hello test" in messages[-1].tool_calls[2].arguments
client_side_tool_call_id = messages[-1].tool_calls[2].tool_call_id
# ensure context is not bricked
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
response = client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -1168,16 +1121,16 @@ def test_parallel_tool_calling(
assert messages[0].message_type == "tool_return_message"
assert len(messages[0].tool_returns) == 4
for tool_return in messages[0].tool_returns:
if tool_return["tool_call_id"] == approve_tool_call_id:
assert tool_return["status"] == "success"
elif tool_return["tool_call_id"] == deny_tool_call_id:
assert tool_return["status"] == "error"
elif tool_return["tool_call_id"] == client_side_tool_call_id:
assert tool_return["status"] == "success"
assert tool_return["tool_return"] == SECRET_CODE
if tool_return.tool_call_id == approve_tool_call_id:
assert tool_return.status == "success"
elif tool_return.tool_call_id == deny_tool_call_id:
assert tool_return.status == "error"
elif tool_return.tool_call_id == client_side_tool_call_id:
assert tool_return.status == "success"
assert tool_return.tool_return == SECRET_CODE
else:
assert tool_return["tool_call_id"] == dice_tool_call_id
assert tool_return["status"] == "success"
assert tool_return.tool_call_id == dice_tool_call_id
assert tool_return.status == "success"
if len(messages) == 3:
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "assistant_message"
@@ -1187,9 +1140,9 @@ def test_parallel_tool_calling(
assert messages[3].message_type == "tool_return_message"
# ensure context is not bricked
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
assert len(messages) > 6
assert messages[0].message_type == "user_message"
assert messages[1].message_type == "reasoning_message"
@@ -1199,7 +1152,7 @@ def test_parallel_tool_calling(
assert messages[5].message_type == "approval_response_message"
assert messages[6].message_type == "tool_return_message"
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,

View File

@@ -8,7 +8,10 @@ from pathlib import Path
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate, ToolCallMessage, ToolReturnMessage
from letta_client import Letta
from letta_client.types import MessageCreateParam
from letta_client.types.agents.tool_call_message import ToolCallMessage
from letta_client.types.agents.tool_return import ToolReturn
from letta.functions.mcp_client.types import StdioServerConfig
from letta.schemas.agent import AgentState
@@ -166,7 +169,7 @@ def test_mcp_echo_tool(client: Letta, agent_state: AgentState):
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user",
content=f"Use the echo tool to echo back this exact message: '{test_message}'",
)
@@ -182,7 +185,7 @@ def test_mcp_echo_tool(client: Letta, agent_state: AgentState):
assert echo_call is not None, f"No echo tool call found. Tool calls: {[m.tool_call.name for m in tool_calls]}"
# Check for tool return message
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
tool_returns = [m for m in response.messages if isinstance(m, ToolReturn)]
assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage"
# Find the return for the echo call
@@ -204,7 +207,7 @@ def test_mcp_add_tool(client: Letta, agent_state: AgentState):
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user",
content=f"Use the add tool to add {a} and {b}.",
)
@@ -220,7 +223,7 @@ def test_mcp_add_tool(client: Letta, agent_state: AgentState):
assert add_call is not None, f"No add tool call found. Tool calls: {[m.tool_call.name for m in tool_calls]}"
# Check for tool return message
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
tool_returns = [m for m in response.messages if isinstance(m, ToolReturn)]
assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage"
# Find the return for the add call
@@ -239,7 +242,7 @@ def test_mcp_multiple_tools_in_sequence(client: Letta, agent_state: AgentState):
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user",
content="First use the add tool to add 10 and 20. Then use the echo tool to echo back the result you got from the add tool.",
)
@@ -256,7 +259,7 @@ def test_mcp_multiple_tools_in_sequence(client: Letta, agent_state: AgentState):
assert "echo" in tool_names, f"echo tool not called. Tools called: {tool_names}"
# Check for tool return messages
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
tool_returns = [m for m in response.messages if isinstance(m, ToolReturn)]
assert len(tool_returns) >= 2, f"Expected at least 2 tool returns, got {len(tool_returns)}"
# Verify all tools succeeded
@@ -339,7 +342,7 @@ def test_mcp_complex_schema_tool(client: Letta, mcp_server_name: str, mock_mcp_s
response = client.agents.messages.create(
agent_id=agent.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user", content='Use the get_parameter_type_description tool with preset "a" to get parameter information.'
)
],
@@ -351,7 +354,7 @@ def test_mcp_complex_schema_tool(client: Letta, mcp_server_name: str, mock_mcp_s
complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None)
assert complex_call is not None, f"No get_parameter_type_description call found. Calls: {[m.tool_call.name for m in tool_calls]}"
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
tool_returns = [m for m in response.messages if isinstance(m, ToolReturn)]
assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage"
complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None)
@@ -363,7 +366,7 @@ def test_mcp_complex_schema_tool(client: Letta, mcp_server_name: str, mock_mcp_s
response = client.agents.messages.create(
agent_id=agent.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user",
content="Use the get_parameter_type_description tool with these arguments: "
'preset="b", connected_service_descriptor="test-service", '
@@ -379,7 +382,7 @@ def test_mcp_complex_schema_tool(client: Letta, mcp_server_name: str, mock_mcp_s
complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None)
assert complex_call is not None, "No get_parameter_type_description call found for nested test"
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
tool_returns = [m for m in response.messages if isinstance(m, ToolReturn)]
complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None)
assert complex_return is not None, "No tool return found for complex nested call"
assert complex_return.status == "success", f"Complex nested call failed with status: {complex_return.status}"

View File

@@ -1,4 +1,4 @@
import asyncio
import ast
import json
import os
import threading
@@ -8,13 +8,9 @@ import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta_client.types.agents import SystemMessage
from letta.config import LettaConfig
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.schemas.letta_message import SystemMessage, ToolReturnMessage
from letta.schemas.tool import Tool
from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
from tests.helpers.utils import retry_until_success
@@ -23,7 +19,7 @@ def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until its accepting connections.
and polls until it's accepting connections.
"""
def _run_server() -> None:
@@ -55,17 +51,6 @@ def server_url() -> str:
yield url
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
config.save()
server = SyncServer()
return server
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
"""
@@ -84,10 +69,11 @@ def remove_stale_agents(client):
@pytest.fixture(scope="function")
def agent_obj(client):
def agent_obj(client: Letta) -> AgentState:
"""Create a test agent that we can call functions on"""
send_message_to_agent_tool = client.tools.list(name="send_message_to_agent_and_wait_for_reply")[0]
send_message_to_agent_tool = list(client.tools.list(name="send_message_to_agent_and_wait_for_reply"))[0]
agent_state_instance = client.agents.create(
agent_type="letta_v1_agent",
include_base_tools=True,
tool_ids=[send_message_to_agent_tool.id],
model="openai/gpt-4o",
@@ -96,13 +82,12 @@ def agent_obj(client):
)
yield agent_state_instance
# client.agents.delete(agent_state_instance.id)
@pytest.fixture(scope="function")
def other_agent_obj(client):
def other_agent_obj(client: Letta) -> AgentState:
"""Create another test agent that we can call functions on"""
agent_state_instance = client.agents.create(
agent_type="letta_v1_agent",
include_base_tools=True,
include_multi_agent_tools=False,
model="openai/gpt-4o",
@@ -112,11 +97,9 @@ def other_agent_obj(client):
yield agent_state_instance
# client.agents.delete(agent_state_instance.id)
@pytest.fixture
def roll_dice_tool(client):
def roll_dice_tool(client: Letta):
def roll_dice():
"""
Rolls a 6 sided die.
@@ -126,19 +109,7 @@ def roll_dice_tool(client):
"""
return "Rolled a 5!"
# Set up tool details
source_code = parse_source_code(roll_dice)
source_type = "python"
description = "test_description"
tags = ["test"]
tool = Tool(description=description, tags=tags, source_code=source_code, source_type=source_type)
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
derived_name = derived_json_schema["name"]
tool.json_schema = derived_json_schema
tool.name = derived_name
# Use SDK method to create tool from function
tool = client.tools.upsert_from_function(func=roll_dice)
# Yield the created tool
@@ -146,86 +117,110 @@ def roll_dice_tool(client):
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agent(client, server, agent_obj, other_agent_obj):
def test_send_message_to_agent(client: Letta, agent_obj: AgentState, other_agent_obj: AgentState):
secret_word = "banana"
actor = asyncio.run(server.user_manager.get_actor_or_default_async())
# Encourage the agent to send a message to the other agent_obj with the secret string
response = client.agents.messages.create(
agent_id=agent_obj.id,
messages=[
{
"role": "user",
"content": f"Use your tool to send a message to another agent with id {other_agent_obj.id} to share the secret word: {secret_word}!",
}
MessageCreateParam(
role="user",
content=f"Use your tool to send a message to another agent with id {other_agent_obj.id} to share the secret word: {secret_word}!",
)
],
)
# Conversation search the other agent
messages = asyncio.run(
server.get_agent_recall(
user_id=actor.id,
agent_id=other_agent_obj.id,
reverse=True,
return_message_object=False,
)
)
# Get messages from the other agent
messages_page = client.agents.messages.list(agent_id=other_agent_obj.id)
messages = messages_page.items
# Check for the presence of system message
# Check for the presence of system message with secret word
found_secret = False
for m in reversed(messages):
print(f"\n\n {other_agent_obj.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word in m.content
break
if secret_word in m.content:
found_secret = True
break
assert found_secret, f"Secret word '{secret_word}' not found in system messages of agent {other_agent_obj.id}"
# Search the sender agent for the response from another agent
in_context_messages = asyncio.run(AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor))
in_context_messages_page = client.agents.messages.list(agent_id=agent_obj.id)
in_context_messages = in_context_messages_page.items
found = False
target_snippet = f"'agent_id': '{other_agent_obj.id}', 'response': ["
for m in in_context_messages:
if target_snippet in m.content[0].text:
found = True
break
# Check ToolReturnMessage for the response
if isinstance(m, ToolReturnMessage):
if target_snippet in m.tool_return:
found = True
break
# Handle different message content structures
elif hasattr(m, "content"):
if isinstance(m.content, list) and len(m.content) > 0:
content_text = m.content[0].text if hasattr(m.content[0], "text") else str(m.content[0])
else:
content_text = str(m.content)
if target_snippet in content_text:
found = True
break
joined = "\n".join([m.content[0].text for m in in_context_messages[1:]])
print(f"In context messages of the sender agent (without system):\n\n{joined}")
if not found:
# Print debug info
joined = "\n".join(
[
str(
m.content[0].text
if hasattr(m, "content") and isinstance(m.content, list) and len(m.content) > 0 and hasattr(m.content[0], "text")
else m.content
if hasattr(m, "content")
else f"<{type(m).__name__}>"
)
for m in in_context_messages[1:]
]
)
print(f"In context messages of the sender agent (without system):\n\n{joined}")
raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}")
# Test that the agent can still receive messages fine
response = client.agents.messages.create(
agent_id=agent_obj.id,
messages=[
{
"role": "user",
"content": "So what did the other agent say?",
}
MessageCreateParam(
role="user",
content="So what did the other agent say?",
)
],
)
print(response.messages)
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_simple(client):
def test_send_message_to_agents_with_tags_simple(client: Letta):
worker_tags_123 = ["worker", "user-123"]
worker_tags_456 = ["worker", "user-456"]
secret_word = "banana"
# Create "manager" agent
send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id
send_message_to_agents_matching_tags_tool_id = list(client.tools.list(name="send_message_to_agents_matching_tags"))[0].id
manager_agent_state = client.agents.create(
agent_type="letta_v1_agent",
name="manager_agent",
tool_ids=[send_message_to_agents_matching_tags_tool_id],
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
# Create 3 non-matching worker agents (These should NOT get the message)
# Create 2 non-matching worker agents (These should NOT get the message)
worker_agents_123 = []
for idx in range(2):
worker_agent_state = client.agents.create(
agent_type="letta_v1_agent",
name=f"not_worker_{idx}",
include_multi_agent_tools=False,
tags=worker_tags_123,
@@ -234,10 +229,11 @@ def test_send_message_to_agents_with_tags_simple(client):
)
worker_agents_123.append(worker_agent_state)
# Create 3 worker agents that should get the message
# Create 2 worker agents that should get the message
worker_agents_456 = []
for idx in range(2):
worker_agent_state = client.agents.create(
agent_type="letta_v1_agent",
name=f"worker_{idx}",
include_multi_agent_tools=False,
tags=worker_tags_456,
@@ -250,69 +246,83 @@ def test_send_message_to_agents_with_tags_simple(client):
response = client.agents.messages.create(
agent_id=manager_agent_state.id,
messages=[
{
"role": "user",
"content": f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
}
MessageCreateParam(
role="user",
content=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
)
],
)
for m in response.messages:
if isinstance(m, ToolReturnMessage):
tool_response = eval(json.loads(m.tool_return)["message"])
tool_response = ast.literal_eval(m.tool_return)
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
assert len(tool_response) == len(worker_agents_456)
# We can break after this, the ToolReturnMessage after is not related
# Verify responses from all expected worker agents
worker_agent_ids = {agent.id for agent in worker_agents_456}
returned_agent_ids = set()
for json_str in tool_response:
response_obj = json.loads(json_str)
assert response_obj["agent_id"] in worker_agent_ids
assert response_obj["response_messages"] != ["<no response>"]
returned_agent_ids.add(response_obj["agent_id"])
break
# Conversation search the worker agents
# Check messages in the worker agents that should have received the message
for agent_state in worker_agents_456:
messages = client.agents.messages.list(agent_state.id)
messages_page = client.agents.messages.list(agent_state.id)
messages = messages_page.items
# Check for the presence of system message
found_secret = False
for m in reversed(messages):
print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word in m.content
break
if secret_word in m.content:
found_secret = True
break
assert found_secret, f"Secret word not found in messages for agent {agent_state.id}"
# Ensure it's NOT in the non matching worker agents
for agent_state in worker_agents_123:
messages = client.agents.messages.list(agent_state.id)
messages_page = client.agents.messages.list(agent_state.id)
messages = messages_page.items
# Check for the presence of system message
for m in reversed(messages):
print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word not in m.content
assert secret_word not in m.content, f"Secret word should not be in agent {agent_state.id}"
# Test that the agent can still receive messages fine
response = client.agents.messages.create(
agent_id=manager_agent_state.id,
messages=[
{
"role": "user",
"content": "So what did the other agent say?",
}
MessageCreateParam(
role="user",
content="So what did the other agent say?",
)
],
)
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
def test_send_message_to_agents_with_tags_complex_tool_use(client: Letta, roll_dice_tool):
# Create "manager" agent
send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id
send_message_to_agents_matching_tags_tool_id = list(client.tools.list(name="send_message_to_agents_matching_tags"))[0].id
manager_agent_state = client.agents.create(
agent_type="letta_v1_agent",
tool_ids=[send_message_to_agents_matching_tags_tool_id],
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
# Create 3 worker agents
# Create 2 worker agents
worker_agents = []
worker_tags = ["dice-rollers"]
for _ in range(2):
worker_agent_state = client.agents.create(
agent_type="letta_v1_agent",
include_multi_agent_tools=False,
tags=worker_tags,
tool_ids=[roll_dice_tool.id],
@@ -326,30 +336,40 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
response = client.agents.messages.create(
agent_id=manager_agent_state.id,
messages=[
{
"role": "user",
"content": broadcast_message,
}
MessageCreateParam(
role="user",
content=broadcast_message,
)
],
)
for m in response.messages:
if isinstance(m, ToolReturnMessage):
tool_response = eval(json.loads(m.tool_return)["message"])
# Parse tool_return string to get list of responses
tool_response = ast.literal_eval(m.tool_return)
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
assert len(tool_response) == len(worker_agents)
# We can break after this, the ToolReturnMessage after is not related
# Verify responses from all expected worker agents
worker_agent_ids = {agent.id for agent in worker_agents}
returned_agent_ids = set()
all_responses = []
for json_str in tool_response:
response_obj = json.loads(json_str)
assert response_obj["agent_id"] in worker_agent_ids
assert response_obj["response_messages"] != ["<no response>"]
returned_agent_ids.add(response_obj["agent_id"])
all_responses.extend(response_obj["response_messages"])
break
# Test that the agent can still receive messages fine
response = client.agents.messages.create(
agent_id=manager_agent_state.id,
messages=[
{
"role": "user",
"content": "So what did the other agent say?",
}
MessageCreateParam(
role="user",
content="So what did the other agent say?",
)
],
)
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))

File diff suppressed because it is too large Load Diff

View File

@@ -1,45 +1,22 @@
import asyncio
import base64
import itertools
import json
import logging
import os
import threading
import time
import uuid
from contextlib import contextmanager
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, List, Tuple
from unittest.mock import patch
from typing import Any, List, Tuple
import httpx
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta, Letta, LettaRequest, MessageCreate, Run
from letta_client.core.api_error import ApiError
from letta_client.types import (
AssistantMessage,
Base64Image,
HiddenReasoningMessage,
ImageContent,
LettaMessageUnion,
LettaStopReason,
LettaUsageStatistics,
ReasoningMessage,
TextContent,
ToolCallMessage,
ToolReturnMessage,
UrlImage,
UserMessage,
)
from letta_client import AsyncLetta
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta_client.types.agents import AssistantMessage, ReasoningMessage, Run, ToolCallMessage, UserMessage
from letta_client.types.agents.letta_streaming_response import LettaPing, LettaStopReason, LettaUsageStatistics
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType, JobStatus
from letta.schemas.letta_message import LettaPing
from letta.schemas.llm_config import LLMConfig
logger = get_logger(__name__)
logger = logging.getLogger(__name__)
# ------------------------------
@@ -56,17 +33,17 @@ all_configs = [
]
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
filename = os.path.join(llm_config_dir, filename)
def get_model_config(filename: str, model_settings_dir: str = "tests/model_settings") -> Tuple[str, dict]:
"""Load a model_settings file and return the handle and settings dict."""
filename = os.path.join(model_settings_dir, filename)
with open(filename, "r") as f:
config_data = json.load(f)
llm_config = LLMConfig(**config_data)
return llm_config
return config_data["handle"], config_data.get("model_settings", {})
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
TESTED_MODEL_CONFIGS: List[Tuple[str, dict]] = [get_model_config(fn) for fn in filenames]
def roll_dice(num_sides: int) -> int:
@@ -84,24 +61,27 @@ def roll_dice(num_sides: int) -> int:
USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_FORCE_REPLY: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=f"This is an automated test message. Reply with the message '{USER_MESSAGE_RESPONSE}'.",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreate] = [
MessageCreate(
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=("This is an automated test message. Please call the roll_dice tool three times in parallel."),
content=(
"This is an automated test message. Please call the roll_dice tool EXACTLY three times in parallel - no more, no less. "
"Call it with num_sides=6, num_sides=12, and num_sides=20. Make all three calls at the same time in a single response."
),
otid=USER_MESSAGE_OTID,
)
]
@@ -109,7 +89,8 @@ USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreate] = [
def assert_greeting_response(
messages: List[Any],
llm_config: LLMConfig,
model_handle: str,
model_settings: dict,
streaming: bool = False,
token_streaming: bool = False,
from_db: bool = False,
@@ -124,7 +105,7 @@ def assert_greeting_response(
]
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
llm_config, streaming=streaming, from_db=from_db
model_handle, model_settings, streaming=streaming, from_db=from_db
)
assert expected_message_count_min <= len(messages) <= expected_message_count_max
@@ -138,7 +119,7 @@ def assert_greeting_response(
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
if is_reasoner_model(llm_config):
if is_reasoner_model(model_handle, model_settings):
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
@@ -169,7 +150,8 @@ def assert_greeting_response(
def assert_tool_call_response(
messages: List[Any],
llm_config: LLMConfig,
model_handle: str,
model_settings: dict,
streaming: bool = False,
from_db: bool = False,
with_cancellation: bool = False,
@@ -190,7 +172,7 @@ def assert_tool_call_response(
if not with_cancellation:
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
llm_config, tool_call=True, streaming=streaming, from_db=from_db
model_handle, model_settings, tool_call=True, streaming=streaming, from_db=from_db
)
assert expected_message_count_min <= len(messages) <= expected_message_count_max
@@ -208,7 +190,7 @@ def assert_tool_call_response(
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
if is_reasoner_model(llm_config):
if is_reasoner_model(model_handle, model_settings):
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
@@ -219,7 +201,7 @@ def assert_tool_call_response(
# Special case for claude-sonnet-4-5-20250929 and opus-4.1 which can generate an extra AssistantMessage before tool call
if (
(llm_config.model == "claude-sonnet-4-5-20250929" or llm_config.model.startswith("claude-opus-4-1"))
("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle)
and index < len(messages)
and isinstance(messages[index], AssistantMessage)
):
@@ -253,7 +235,7 @@ def assert_tool_call_response(
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
if is_reasoner_model(llm_config):
if is_reasoner_model(model_handle, model_settings):
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
@@ -279,37 +261,94 @@ def assert_tool_call_response(
assert messages[index].step_count > 0
async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -> List[Any]:
async def accumulate_chunks(chunks, verify_token_streaming: bool = False) -> List[Any]:
"""
Accumulates chunks into a list of messages.
Handles both async iterators and raw SSE strings.
"""
messages = []
current_message = None
prev_message_type = None
chunk_count = 0
async for chunk in chunks:
current_message_type = chunk.message_type
if prev_message_type != current_message_type:
messages.append(current_message)
if (
prev_message_type
and verify_token_streaming
and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
):
assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}. Messages: {messages}"
current_message = None
chunk_count = 0
if current_message is None:
current_message = chunk
else:
pass # TODO: actually accumulate the chunks. For now we only care about the count
prev_message_type = current_message_type
chunk_count += 1
messages.append(current_message)
if verify_token_streaming and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]:
assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}"
return [m for m in messages if m is not None]
# Handle raw SSE string from runs.messages.stream()
if isinstance(chunks, str):
import json
for line in chunks.strip().split("\n"):
if line.startswith("data: ") and line != "data: [DONE]":
try:
data = json.loads(line[6:]) # Remove 'data: ' prefix
if "message_type" in data:
# Create proper message type objects
message_type = data.get("message_type")
if message_type == "assistant_message":
from letta_client.types.agents import AssistantMessage
chunk = AssistantMessage(**data)
elif message_type == "reasoning_message":
from letta_client.types.agents import ReasoningMessage
chunk = ReasoningMessage(**data)
elif message_type == "tool_call_message":
from letta_client.types.agents import ToolCallMessage
chunk = ToolCallMessage(**data)
elif message_type == "tool_return_message":
from letta_client.types import ToolReturnMessage
chunk = ToolReturnMessage(**data)
elif message_type == "user_message":
from letta_client.types.agents import UserMessage
chunk = UserMessage(**data)
elif message_type == "stop_reason":
from letta_client.types.agents.letta_streaming_response import LettaStopReason
chunk = LettaStopReason(**data)
elif message_type == "usage_statistics":
from letta_client.types.agents.letta_streaming_response import LettaUsageStatistics
chunk = LettaUsageStatistics(**data)
else:
chunk = type("Chunk", (), data)() # Fallback for unknown types
current_message_type = chunk.message_type
if prev_message_type != current_message_type:
if current_message is not None:
messages.append(current_message)
current_message = chunk
else:
# Accumulate content for same message type
if hasattr(current_message, "content") and hasattr(chunk, "content"):
current_message.content += chunk.content
prev_message_type = current_message_type
except json.JSONDecodeError:
continue
if current_message is not None:
messages.append(current_message)
else:
# Handle async iterator from agents.messages.stream()
async for chunk in chunks:
current_message_type = chunk.message_type
if prev_message_type != current_message_type:
if current_message is not None:
messages.append(current_message)
current_message = chunk
else:
# Accumulate content for same message type
if hasattr(current_message, "content") and hasattr(chunk, "content"):
current_message.content += chunk.content
prev_message_type = current_message_type
if current_message is not None:
messages.append(current_message)
return messages
async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5):
@@ -334,7 +373,7 @@ async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: floa
def get_expected_message_count_range(
llm_config: LLMConfig, tool_call: bool = False, streaming: bool = False, from_db: bool = False
model_handle: str, model_settings: dict, tool_call: bool = False, streaming: bool = False, from_db: bool = False
) -> Tuple[int, int]:
"""
Returns the expected range of number of messages for a given LLM configuration. Uses range to account for possible variations in the number of reasoning messages.
@@ -363,23 +402,26 @@ def get_expected_message_count_range(
expected_message_count = 1
expected_range = 0
if is_reasoner_model(llm_config):
if is_reasoner_model(model_handle, model_settings):
# reasoning message
expected_range += 1
if tool_call:
# check for sonnet 4.5 or opus 4.1 specifically
is_sonnet_4_5_or_opus_4_1 = (
llm_config.model_endpoint_type == "anthropic"
and llm_config.enable_reasoner
and (llm_config.model.startswith("claude-sonnet-4-5") or llm_config.model.startswith("claude-opus-4-1"))
model_settings.get("provider_type") == "anthropic"
and model_settings.get("thinking", {}).get("type") == "enabled"
and ("claude-sonnet-4-5" in model_handle or "claude-opus-4-1" in model_handle)
)
if is_sonnet_4_5_or_opus_4_1 or not LLMConfig.is_anthropic_reasoning_model(llm_config):
is_anthropic_reasoning = (
model_settings.get("provider_type") == "anthropic" and model_settings.get("thinking", {}).get("type") == "enabled"
)
if is_sonnet_4_5_or_opus_4_1 or not is_anthropic_reasoning:
# sonnet 4.5 and opus 4.1 return a reasoning message before the final assistant message
# so do the other native reasoning models
expected_range += 1
# opus 4.1 generates an extra AssistantMessage before the tool call
if llm_config.model.startswith("claude-opus-4-1"):
if "claude-opus-4-1" in model_handle:
expected_range += 1
if tool_call:
@@ -397,13 +439,34 @@ def get_expected_message_count_range(
return expected_message_count, expected_message_count + expected_range
def is_reasoner_model(llm_config: LLMConfig) -> bool:
return (
(LLMConfig.is_openai_reasoning_model(llm_config) and llm_config.reasoning_effort == "high")
or LLMConfig.is_anthropic_reasoning_model(llm_config)
or LLMConfig.is_google_vertex_reasoning_model(llm_config)
or LLMConfig.is_google_ai_reasoning_model(llm_config)
def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
"""Check if the model is a reasoning model based on its handle and settings."""
# OpenAI reasoning models with high reasoning effort
is_openai_reasoning = (
model_settings.get("provider_type") == "openai"
and (
"gpt-5" in model_handle
or "o1" in model_handle
or "o3" in model_handle
or "o4-mini" in model_handle
or "gpt-4.1" in model_handle
)
and model_settings.get("reasoning", {}).get("reasoning_effort") == "high"
)
# Anthropic models with thinking enabled
is_anthropic_reasoning = (
model_settings.get("provider_type") == "anthropic" and model_settings.get("thinking", {}).get("type") == "enabled"
)
# Google Vertex models with thinking config
is_google_vertex_reasoning = (
model_settings.get("provider_type") == "google_vertex" and model_settings.get("thinking_config", {}).get("include_thoughts") is True
)
# Google AI models with thinking config
is_google_ai_reasoning = (
model_settings.get("provider_type") == "google_ai" and model_settings.get("thinking_config", {}).get("include_thoughts") is True
)
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning
# ------------------------------
@@ -466,7 +529,7 @@ async def agent_state(client: AsyncLetta) -> AgentState:
dice_tool = await client.tools.upsert_from_function(func=roll_dice)
agent_state_instance = await client.agents.create(
agent_type=AgentType.letta_v1_agent,
agent_type="letta_v1_agent",
name="test_agent",
include_base_tools=False,
tool_ids=[dice_tool.id],
@@ -485,9 +548,9 @@ async def agent_state(client: AsyncLetta) -> AgentState:
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
"model_config",
TESTED_MODEL_CONFIGS,
ids=[handle for handle, _ in TESTED_MODEL_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
@@ -495,11 +558,13 @@ async def test_greeting(
disable_e2b_api_key: Any,
client: AsyncLetta,
agent_state: AgentState,
llm_config: LLMConfig,
model_config: Tuple[str, dict],
send_type: str,
) -> None:
last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
model_handle, model_settings = model_config
last_message_page = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
last_message = last_message_page.items[0] if last_message_page.items else None
agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings)
if send_type == "step":
response = await client.agents.messages.create(
@@ -507,50 +572,55 @@ async def test_greeting(
messages=USER_MESSAGE_FORCE_REPLY,
)
messages = response.messages
run_id = next((m.run_id for m in messages if hasattr(m, "run_id") and m.run_id), None)
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
elif send_type == "async":
run = await client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
run = await wait_for_run_completion(client, run.id)
messages = await client.runs.messages.list(run_id=run.id)
messages = [m for m in messages if m.message_type != "user_message"]
run = await wait_for_run_completion(client, run.id, timeout=60.0)
messages_page = await client.runs.messages.list(run_id=run.id)
messages = [m for m in messages_page.items if m.message_type != "user_message"]
run_id = run.id
else:
response = client.agents.messages.create_stream(
response = await client.agents.messages.stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
stream_tokens=(send_type == "stream_tokens"),
background=(send_type == "stream_tokens_background"),
)
messages = await accumulate_chunks(response)
run_id = next((m.run_id for m in messages if hasattr(m, "run_id") and m.run_id), None)
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
# If run_id is not in messages (e.g., due to early cancellation), get the most recent run
if run_id is None:
runs = await client.runs.list(agent_ids=[agent_state.id])
run_id = runs.items[0].id if runs.items else None
assert_greeting_response(
messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config
messages, model_handle, model_settings, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens")
)
if "background" in send_type:
response = client.runs.stream(run_id=run_id, starting_after=0)
response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
messages = await accumulate_chunks(response)
assert_greeting_response(
messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config
messages, model_handle, model_settings, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens")
)
messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_response(messages_from_db, from_db=True, llm_config=llm_config)
messages_from_db_page = await client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None)
messages_from_db = messages_from_db_page.items
assert_greeting_response(messages_from_db, model_handle, model_settings, from_db=True)
assert run_id is not None
run = await client.runs.retrieve(run_id=run_id)
assert run.status == JobStatus.completed
assert run.status == "completed"
@pytest.mark.skip(reason="Skipping parallel tool calling test until it is fixed")
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
"model_config",
TESTED_MODEL_CONFIGS,
ids=[handle for handle, _ in TESTED_MODEL_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
@@ -558,28 +628,30 @@ async def test_parallel_tool_calls(
disable_e2b_api_key: Any,
client: AsyncLetta,
agent_state: AgentState,
llm_config: LLMConfig,
model_config: Tuple[str, dict],
send_type: str,
) -> None:
if llm_config.model_endpoint_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
model_handle, model_settings = model_config
provider_type = model_settings.get("provider_type", "")
if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.")
if llm_config.model in ["gpt-5", "o3"]:
if "gpt-5" in model_handle or "o3" in model_handle:
pytest.skip("GPT-5 takes too long to test, o3 is bad at this task.")
# change llm_config to support parallel tool calling
# Create a copy and modify it to ensure we're not modifying the original
modified_llm_config = llm_config.model_copy(deep=True)
modified_llm_config.parallel_tool_calls = True
# this test was flaking so set temperature to 0.0 to avoid randomness
modified_llm_config.temperature = 0.0
# Skip Gemini models due to issues with parallel tool calling
if provider_type in ["google_ai", "google_vertex"]:
pytest.skip("Gemini models are flaky for this test so we disable them for now")
# IMPORTANT: Set parallel_tool_calls at BOTH the agent level and llm_config level
# There are two different parallel_tool_calls fields that need to be set
agent_state = await client.agents.modify(
# Update model_settings to enable parallel tool calling
modified_model_settings = model_settings.copy()
modified_model_settings["parallel_tool_calls"] = True
agent_state = await client.agents.update(
agent_id=agent_state.id,
llm_config=modified_llm_config,
parallel_tool_calls=True, # Set at agent level as well!
model=model_handle,
model_settings=modified_model_settings,
)
if send_type == "step":
@@ -592,9 +664,9 @@ async def test_parallel_tool_calls(
agent_id=agent_state.id,
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
)
await wait_for_run_completion(client, run.id)
await wait_for_run_completion(client, run.id, timeout=60.0)
else:
response = client.agents.messages.create_stream(
response = await client.agents.messages.stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
stream_tokens=(send_type == "stream_tokens"),
@@ -603,53 +675,134 @@ async def test_parallel_tool_calls(
await accumulate_chunks(response)
# validate parallel tool call behavior in preserved messages
preserved_messages = await client.agents.messages.list(agent_id=agent_state.id)
preserved_messages_page = await client.agents.messages.list(agent_id=agent_state.id)
preserved_messages = preserved_messages_page.items
# find the tool call message in preserved messages
tool_call_msg = None
tool_return_msg = None
# collect all ToolCallMessage and ToolReturnMessage instances
tool_call_messages = []
tool_return_messages = []
for msg in preserved_messages:
if isinstance(msg, ToolCallMessage):
tool_call_msg = msg
tool_call_messages.append(msg)
elif isinstance(msg, ToolReturnMessage):
tool_return_msg = msg
tool_return_messages.append(msg)
# assert parallel tool calls were made
assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"
# Check if tool calls are grouped in a single message (parallel) or separate messages (sequential)
total_tool_calls = 0
for i, tcm in enumerate(tool_call_messages):
if hasattr(tcm, "tool_calls") and tcm.tool_calls:
num_calls = len(tcm.tool_calls) if isinstance(tcm.tool_calls, list) else 1
total_tool_calls += num_calls
elif hasattr(tcm, "tool_call"):
total_tool_calls += 1
# verify each tool call
for tc in tool_call_msg.tool_calls:
assert tc["name"] == "roll_dice"
# Check tool returns structure
total_tool_returns = 0
for i, trm in enumerate(tool_return_messages):
if hasattr(trm, "tool_returns") and trm.tool_returns:
num_returns = len(trm.tool_returns) if isinstance(trm.tool_returns, list) else 1
total_tool_returns += num_returns
elif hasattr(trm, "tool_return"):
total_tool_returns += 1
# CRITICAL: For TRUE parallel tool calling with letta_v1_agent, there should be exactly ONE ToolCallMessage
# containing multiple tool calls, not multiple ToolCallMessages
# Verify we have exactly 3 tool calls total
assert total_tool_calls == 3, f"Expected exactly 3 tool calls total, got {total_tool_calls}"
assert total_tool_returns == 3, f"Expected exactly 3 tool returns total, got {total_tool_returns}"
# Check if we have true parallel tool calling
is_parallel = False
if len(tool_call_messages) == 1:
# Check if the single message contains multiple tool calls
tcm = tool_call_messages[0]
if hasattr(tcm, "tool_calls") and isinstance(tcm.tool_calls, list) and len(tcm.tool_calls) == 3:
is_parallel = True
# IMPORTANT: Assert that parallel tool calling is actually working
# This test should FAIL if parallel tool calling is not working properly
assert is_parallel, (
f"Parallel tool calling is NOT working for {provider_type}! "
f"Got {len(tool_call_messages)} ToolCallMessage(s) instead of 1 with 3 parallel calls. "
f"When using letta_v1_agent with parallel_tool_calls=True, all tool calls should be in a single message."
)
# Collect all tool calls and their details for validation
all_tool_calls = []
tool_call_ids = set()
num_sides_by_id = {}
for tcm in tool_call_messages:
if hasattr(tcm, "tool_calls") and tcm.tool_calls and isinstance(tcm.tool_calls, list):
# Message has multiple tool calls
for tc in tcm.tool_calls:
all_tool_calls.append(tc)
tool_call_ids.add(tc.tool_call_id)
# Parse arguments
import json
args = json.loads(tc.arguments)
num_sides_by_id[tc.tool_call_id] = int(args["num_sides"])
elif hasattr(tcm, "tool_call") and tcm.tool_call:
# Message has single tool call
tc = tcm.tool_call
all_tool_calls.append(tc)
tool_call_ids.add(tc.tool_call_id)
# Parse arguments
import json
args = json.loads(tc.arguments)
num_sides_by_id[tc.tool_call_id] = int(args["num_sides"])
# Verify each tool call
for tc in all_tool_calls:
assert tc.name == "roll_dice", f"Expected tool call name 'roll_dice', got '{tc.name}'"
# Support Anthropic (toolu_), OpenAI (call_), and Gemini (UUID) tool call ID formats
# Gemini uses UUID format which could start with any alphanumeric character
valid_id_format = (
tc["tool_call_id"].startswith("toolu_")
or tc["tool_call_id"].startswith("call_")
or (len(tc["tool_call_id"]) > 0 and tc["tool_call_id"][0].isalnum()) # UUID format for Gemini
tc.tool_call_id.startswith("toolu_")
or tc.tool_call_id.startswith("call_")
or (len(tc.tool_call_id) > 0 and tc.tool_call_id[0].isalnum()) # UUID format for Gemini
)
assert valid_id_format, f"Unexpected tool call ID format: {tc['tool_call_id']}"
assert "num_sides" in tc["arguments"]
assert valid_id_format, f"Unexpected tool call ID format: {tc.tool_call_id}"
# assert tool returns match the tool calls
assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"
# Collect all tool returns for validation
all_tool_returns = []
for trm in tool_return_messages:
if hasattr(trm, "tool_returns") and trm.tool_returns and isinstance(trm.tool_returns, list):
# Message has multiple tool returns
all_tool_returns.extend(trm.tool_returns)
elif hasattr(trm, "tool_return") and trm.tool_return:
# Message has single tool return (create a mock object if needed)
# Since ToolReturnMessage might not have individual tool_return, check the structure
pass
# verify each tool return
tool_call_ids = {tc["tool_call_id"] for tc in tool_call_msg.tool_calls}
for tr in tool_return_msg.tool_returns:
assert tr["type"] == "tool"
assert tr["status"] == "success"
assert tr["tool_call_id"] in tool_call_ids, f"tool_call_id {tr['tool_call_id']} not found in tool calls"
assert int(tr["tool_return"]) >= 1 and int(tr["tool_return"]) <= 6
# If all_tool_returns is empty, it means returns are structured differently
# Let's check the actual structure
if not all_tool_returns:
print("Note: Tool returns may be structured differently than expected")
# For now, just verify we got the right number of messages
assert len(tool_return_messages) > 0, "No tool return messages found"
# Verify tool returns if we have them in the expected format
for tr in all_tool_returns:
assert tr.type == "tool", f"Tool return type should be 'tool', got '{tr.type}'"
assert tr.status == "success", f"Tool return status should be 'success', got '{tr.status}'"
assert tr.tool_call_id in tool_call_ids, f"Tool return ID '{tr.tool_call_id}' not found in tool call IDs: {tool_call_ids}"
# Verify the dice roll result is within the valid range
dice_result = int(tr.tool_return)
expected_max = num_sides_by_id[tr.tool_call_id]
assert 1 <= dice_result <= expected_max, (
f"Dice roll result {dice_result} is not within valid range 1-{expected_max} for tool call {tr.tool_call_id}"
)
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
"model_config",
TESTED_MODEL_CONFIGS,
ids=[handle for handle, _ in TESTED_MODEL_CONFIGS],
)
@pytest.mark.parametrize(
["send_type", "cancellation"],
@@ -670,19 +823,22 @@ async def test_tool_call(
disable_e2b_api_key: Any,
client: AsyncLetta,
agent_state: AgentState,
llm_config: LLMConfig,
model_config: Tuple[str, dict],
send_type: str,
cancellation: str,
) -> None:
# Skip models with OTID mismatch issues between ToolCallMessage and ToolReturnMessage
if llm_config.model == "gpt-5" or llm_config.model == "claude-sonnet-4-5-20250929" or llm_config.model.startswith("claude-opus-4-1"):
pytest.skip(f"Skipping {llm_config.model} due to OTID chain issue - messages receive incorrect OTID suffixes")
model_handle, model_settings = model_config
last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
# Skip models with OTID mismatch issues between ToolCallMessage and ToolReturnMessage
if "gpt-5" in model_handle or "claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle:
pytest.skip(f"Skipping {model_handle} due to OTID chain issue - messages receive incorrect OTID suffixes")
last_message_page = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
last_message = last_message_page.items[0] if last_message_page.items else None
agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings)
if cancellation == "with_cancellation":
delay = 5 if llm_config.model == "gpt-5" else 0.5 # increase delay for responses api
delay = 5 if "gpt-5" in model_handle else 0.5 # increase delay for responses api
_cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))
if send_type == "step":
@@ -691,45 +847,50 @@ async def test_tool_call(
messages=USER_MESSAGE_ROLL_DICE,
)
messages = response.messages
run_id = next((m.run_id for m in messages if hasattr(m, "run_id") and m.run_id), None)
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
elif send_type == "async":
run = await client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
)
run = await wait_for_run_completion(client, run.id)
messages = await client.runs.messages.list(run_id=run.id)
messages = [m for m in messages if m.message_type != "user_message"]
run = await wait_for_run_completion(client, run.id, timeout=60.0)
messages_page = await client.runs.messages.list(run_id=run.id)
messages = [m for m in messages_page.items if m.message_type != "user_message"]
run_id = run.id
else:
response = client.agents.messages.create_stream(
response = await client.agents.messages.stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
stream_tokens=(send_type == "stream_tokens"),
background=(send_type == "stream_tokens_background"),
)
messages = await accumulate_chunks(response)
run_id = next((m.run_id for m in messages if hasattr(m, "run_id") and m.run_id), None)
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
# If run_id is not in messages (e.g., due to early cancellation), get the most recent run
if run_id is None:
runs = await client.runs.list(agent_ids=[agent_state.id])
run_id = runs[0].id if runs else None
run_id = runs.items[0].id if runs.items else None
assert_tool_call_response(
messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
messages, model_handle, model_settings, streaming=("stream" in send_type), with_cancellation=(cancellation == "with_cancellation")
)
if "background" in send_type:
response = client.runs.stream(run_id=run_id, starting_after=0)
response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
messages = await accumulate_chunks(response)
assert_tool_call_response(
messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
messages,
model_handle,
model_settings,
streaming=("stream" in send_type),
with_cancellation=(cancellation == "with_cancellation"),
)
messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
messages_from_db_page = await client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None)
messages_from_db = messages_from_db_page.items
assert_tool_call_response(
messages_from_db, from_db=True, llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
messages_from_db, model_handle, model_settings, from_db=True, with_cancellation=(cancellation == "with_cancellation")
)
assert run_id is not None

View File

@@ -5,16 +5,10 @@ import time
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
from letta_client.core.api_error import ApiError
from letta_client import APIError, Letta
from letta_client.types import CreateBlockParam, MessageCreateParam, SleeptimeManagerParam
from letta.constants import DEFAULT_HUMAN
from letta.orm.errors import NoResultFound
from letta.schemas.block import CreateBlock
from letta.schemas.enums import AgentType, JobStatus, JobType, ToolRuleType
from letta.schemas.group import ManagerType, SleeptimeManagerUpdate
from letta.schemas.message import MessageCreate
from letta.schemas.run import Run
from letta.utils import get_human_text, get_persona_text
@@ -74,11 +68,11 @@ async def test_sleeptime_group_chat(client):
main_agent = client.agents.create(
name="main_agent",
memory_blocks=[
CreateBlock(
CreateBlockParam(
label="persona",
value="You are a personal assistant that helps users with requests.",
),
CreateBlock(
CreateBlockParam(
label="human",
value="My favorite plant is the fiddle leaf\nMy favorite color is lavender",
),
@@ -86,7 +80,7 @@ async def test_sleeptime_group_chat(client):
model="anthropic/claude-sonnet-4-5-20250929",
embedding="openai/text-embedding-3-small",
enable_sleeptime=True,
agent_type=AgentType.letta_v1_agent,
agent_type="letta_v1_agent",
)
assert main_agent.enable_sleeptime == True
@@ -96,34 +90,35 @@ async def test_sleeptime_group_chat(client):
assert "archival_memory_insert" not in main_agent_tools
# 2. Override frequency for test
group = client.groups.modify(
group = client.groups.update(
group_id=main_agent.multi_agent_group.id,
manager_config=SleeptimeManagerUpdate(
manager_config=SleeptimeManagerParam(
manager_type="sleeptime",
sleeptime_agent_frequency=2,
),
)
assert group.manager_type == ManagerType.sleeptime
assert group.manager_type == "sleeptime"
assert group.sleeptime_agent_frequency == 2
assert len(group.agent_ids) == 1
# 3. Verify shared blocks
sleeptime_agent_id = group.agent_ids[0]
shared_block = client.agents.blocks.retrieve(agent_id=main_agent.id, block_label="human")
agents = client.blocks.agents.list(block_id=shared_block.id)
agents = client.blocks.agents.list(block_id=shared_block.id).items
assert len(agents) == 2
assert sleeptime_agent_id in [agent.id for agent in agents]
assert main_agent.id in [agent.id for agent in agents]
# 4 Verify sleeptime agent tools
sleeptime_agent = client.agents.retrieve(agent_id=sleeptime_agent_id)
sleeptime_agent = client.agents.retrieve(agent_id=sleeptime_agent_id, include=["agent.tools"])
sleeptime_agent_tools = [tool.name for tool in sleeptime_agent.tools]
assert "memory_rethink" in sleeptime_agent_tools
assert "memory_finish_edits" in sleeptime_agent_tools
assert "memory_replace" in sleeptime_agent_tools
assert "memory_insert" in sleeptime_agent_tools
assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == ToolRuleType.exit_loop]) > 0
assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == "exit_loop"]) > 0
# 5. Send messages and verify run ids
message_text = [
@@ -139,7 +134,7 @@ async def test_sleeptime_group_chat(client):
response = client.agents.messages.create(
agent_id=main_agent.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user",
content=text,
),
@@ -150,22 +145,21 @@ async def test_sleeptime_group_chat(client):
assert len(response.usage.run_ids or []) == (i + 1) % 2
run_ids.extend(response.usage.run_ids or [])
runs = client.runs.list()
agent_runs = [run for run in runs if run.agent_id == sleeptime_agent_id]
assert len(agent_runs) == len(run_ids)
runs = client.runs.list(agent_id=sleeptime_agent_id).items
assert len(runs) == len(run_ids)
# 6. Verify run status after sleep
time.sleep(2)
for run_id in run_ids:
job = client.runs.retrieve(run_id=run_id)
assert job.status == JobStatus.running or job.status == JobStatus.completed
assert job.status == "running" or job.status == "completed"
# 7. Delete agent
client.agents.delete(agent_id=main_agent.id)
with pytest.raises(ApiError):
with pytest.raises(APIError):
client.groups.retrieve(group_id=group.id)
with pytest.raises(ApiError):
with pytest.raises(APIError):
client.agents.retrieve(agent_id=sleeptime_agent_id)
@@ -177,11 +171,11 @@ async def test_sleeptime_removes_redundant_information(client):
main_agent = client.agents.create(
name="main_agent",
memory_blocks=[
CreateBlock(
CreateBlockParam(
label="persona",
value="You are a personal assistant that helps users with requests.",
),
CreateBlock(
CreateBlockParam(
label="human",
value="My favorite plant is the fiddle leaf\nMy favorite dog is the husky\nMy favorite plant is the fiddle leaf\nMy favorite plant is the fiddle leaf",
),
@@ -189,12 +183,13 @@ async def test_sleeptime_removes_redundant_information(client):
model="anthropic/claude-sonnet-4-5-20250929",
embedding="openai/text-embedding-3-small",
enable_sleeptime=True,
agent_type=AgentType.letta_v1_agent,
agent_type="letta_v1_agent",
)
group = client.groups.modify(
group = client.groups.update(
group_id=main_agent.multi_agent_group.id,
manager_config=SleeptimeManagerUpdate(
manager_config=SleeptimeManagerParam(
manager_type="sleeptime",
sleeptime_agent_frequency=1,
),
)
@@ -207,7 +202,7 @@ async def test_sleeptime_removes_redundant_information(client):
_ = client.agents.messages.create(
agent_id=main_agent.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user",
content=test_message,
),
@@ -224,9 +219,9 @@ async def test_sleeptime_removes_redundant_information(client):
# 4. Delete agent
client.agents.delete(agent_id=main_agent.id)
with pytest.raises(ApiError):
with pytest.raises(APIError):
client.groups.retrieve(group_id=group.id)
with pytest.raises(ApiError):
with pytest.raises(APIError):
client.agents.retrieve(agent_id=sleeptime_agent_id)
@@ -234,19 +229,19 @@ async def test_sleeptime_removes_redundant_information(client):
async def test_sleeptime_edit(client):
sleeptime_agent = client.agents.create(
name="sleeptime_agent",
agent_type=AgentType.sleeptime_agent,
agent_type="sleeptime_agent",
memory_blocks=[
CreateBlock(
CreateBlockParam(
label="human",
value=get_human_text(DEFAULT_HUMAN),
limit=2000,
),
CreateBlock(
CreateBlockParam(
label="memory_persona",
value=get_persona_text("sleeptime_memory_persona"),
limit=2000,
),
CreateBlock(
CreateBlockParam(
label="fact_block",
value="""Messi resides in the Paris.
Messi plays in the league Ligue 1.
@@ -265,7 +260,7 @@ async def test_sleeptime_edit(client):
_ = client.agents.messages.create(
agent_id=sleeptime_agent.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user",
content="Messi has now moved to playing for Inter Miami",
),
@@ -286,11 +281,11 @@ async def test_sleeptime_agent_new_block_attachment(client):
main_agent = client.agents.create(
name="main_agent",
memory_blocks=[
CreateBlock(
CreateBlockParam(
label="persona",
value="You are a personal assistant that helps users with requests.",
),
CreateBlock(
CreateBlockParam(
label="human",
value="My favorite plant is the fiddle leaf\nMy favorite color is lavender",
),
@@ -298,7 +293,7 @@ async def test_sleeptime_agent_new_block_attachment(client):
model="anthropic/claude-sonnet-4-5-20250929",
embedding="openai/text-embedding-3-small",
enable_sleeptime=True,
agent_type=AgentType.letta_v1_agent,
agent_type="letta_v1_agent",
)
assert main_agent.enable_sleeptime == True
@@ -308,13 +303,13 @@ async def test_sleeptime_agent_new_block_attachment(client):
sleeptime_agent_id = group.agent_ids[0]
# 3. Verify initial shared blocks
main_agent_refreshed = client.agents.retrieve(agent_id=main_agent.id)
main_agent_refreshed = client.agents.retrieve(agent_id=main_agent.id, include=["agent.blocks"])
initial_blocks = main_agent_refreshed.memory.blocks
initial_block_count = len(initial_blocks)
# Verify both agents share the initial blocks
for block in initial_blocks:
agents = client.blocks.agents.list(block_id=block.id)
agents = client.blocks.agents.list(block_id=block.id).items
assert len(agents) == 2
assert sleeptime_agent_id in [agent.id for agent in agents]
assert main_agent.id in [agent.id for agent in agents]
@@ -331,14 +326,14 @@ async def test_sleeptime_agent_new_block_attachment(client):
client.agents.blocks.attach(agent_id=main_agent.id, block_id=new_block.id)
# 6. Verify the new block is attached to the main agent
main_agent_refreshed = client.agents.retrieve(agent_id=main_agent.id)
main_agent_refreshed = client.agents.retrieve(agent_id=main_agent.id, include=["agent.blocks"])
main_agent_blocks = main_agent_refreshed.memory.blocks
assert len(main_agent_blocks) == initial_block_count + 1
main_agent_block_ids = [block.id for block in main_agent_blocks]
assert new_block.id in main_agent_block_ids
# 7. Check if the new block is also attached to the sleeptime agent (this is where the bug might be)
sleeptime_agent = client.agents.retrieve(agent_id=sleeptime_agent_id)
sleeptime_agent = client.agents.retrieve(agent_id=sleeptime_agent_id, include=["agent.blocks"])
sleeptime_agent_blocks = sleeptime_agent.memory.blocks
sleeptime_agent_block_ids = [block.id for block in sleeptime_agent_blocks]
@@ -346,7 +341,7 @@ async def test_sleeptime_agent_new_block_attachment(client):
assert new_block.id in sleeptime_agent_block_ids, f"New block {new_block.id} not attached to sleeptime agent {sleeptime_agent_id}"
# 8. Verify that agents sharing the new block include both main and sleeptime agents
agents_with_new_block = client.blocks.agents.list(block_id=new_block.id)
agents_with_new_block = client.blocks.agents.list(block_id=new_block.id).items
agent_ids_with_new_block = [agent.id for agent in agents_with_new_block]
assert main_agent.id in agent_ids_with_new_block, "Main agent should have access to the new block"

View File

@@ -1,11 +1,42 @@
from conftest import create_test_module
AGENTS_CREATE_PARAMS = [
("caren_agent", {"name": "caren", "model": "openai/gpt-4o-mini", "embedding": "openai/text-embedding-3-small"}, {}, None),
(
"caren_agent",
{"name": "caren", "model": "openai/gpt-4o-mini", "embedding": "openai/text-embedding-3-small"},
{
# Verify model_settings is populated with config values
# Note: The 'model' field itself is separate from model_settings
"model_settings": {
"max_output_tokens": 4096,
"parallel_tool_calls": False,
"provider_type": "openai",
"temperature": 0.7,
"reasoning": {"reasoning_effort": "minimal"},
"response_format": None,
}
},
None,
),
]
AGENTS_MODIFY_PARAMS = [
("caren_agent", {"name": "caren_updated"}, {}, None),
AGENTS_UPDATE_PARAMS = [
(
"caren_agent",
{"name": "caren_updated"},
{
# After updating just the name, model_settings should still be present
"model_settings": {
"max_output_tokens": 4096,
"parallel_tool_calls": False,
"provider_type": "openai",
"temperature": 0.7,
"reasoning": {"reasoning_effort": "minimal"},
"response_format": None,
}
},
None,
),
]
AGENTS_LIST_PARAMS = [
@@ -19,7 +50,7 @@ globals().update(
resource_name="agents",
id_param_name="agent_id",
create_params=AGENTS_CREATE_PARAMS,
modify_params=AGENTS_MODIFY_PARAMS,
update_params=AGENTS_UPDATE_PARAMS,
list_params=AGENTS_LIST_PARAMS,
)
)

View File

@@ -1,5 +1,5 @@
from conftest import create_test_module
from letta_client.errors import UnprocessableEntityError
from letta_client import UnprocessableEntityError
from letta.constants import CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT
@@ -8,7 +8,7 @@ BLOCKS_CREATE_PARAMS = [
("persona_block", {"label": "persona", "value": "test1"}, {"limit": CORE_MEMORY_PERSONA_CHAR_LIMIT}, None),
]
BLOCKS_MODIFY_PARAMS = [
BLOCKS_UPDATE_PARAMS = [
("human_block", {"value": "test2"}, {}, None),
("persona_block", {"value": "testing testing testing", "limit": 10}, {}, UnprocessableEntityError),
]
@@ -25,7 +25,7 @@ globals().update(
resource_name="blocks",
id_param_name="block_id",
create_params=BLOCKS_CREATE_PARAMS,
modify_params=BLOCKS_MODIFY_PARAMS,
update_params=BLOCKS_UPDATE_PARAMS,
list_params=BLOCKS_LIST_PARAMS,
)
)

View File

@@ -48,14 +48,12 @@ def server_url() -> str:
# This fixture creates a client for each test module
@pytest.fixture(scope="session")
def client(server_url):
print("Running client tests with server:", server_url)
# Overide the base_url if the LETTA_API_URL is set
api_url = os.getenv("LETTA_API_URL")
base_url = api_url if api_url else server_url
# create the Letta client
yield Letta(base_url=base_url, token=None, timeout=300.0)
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
yield client_instance
def skip_test_if_not_implemented(handler, resource_name, test_name):
@@ -68,7 +66,7 @@ def create_test_module(
id_param_name: str,
create_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
upsert_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
modify_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
update_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
list_params: List[Tuple[Dict[str, Any], int]] = [],
) -> Dict[str, Any]:
"""Create a test module for a resource.
@@ -80,7 +78,7 @@ def create_test_module(
resource_name: Name of the resource (e.g., "blocks", "tools")
id_param_name: Name of the ID parameter (e.g., "block_id", "tool_id")
create_params: List of (name, params, expected_error) tuples for create tests
modify_params: List of (name, params, expected_error) tuples for modify tests
update_params: List of (name, params, expected_error) tuples for update tests
list_params: List of (query_params, expected_count) tuples for list tests
Returns:
@@ -138,11 +136,7 @@ def create_test_module(
expected_values = processed_params | processed_extra_expected
for key, value in expected_values.items():
if hasattr(item, key):
if key == "model" or key == "embedding":
# NOTE: add back these tests after v1 migration
continue
print(f"item.{key}: {getattr(item, key)}")
assert custom_model_dump(getattr(item, key)) == value, f"For key {key}, expected {value}, but got {getattr(item, key)}"
assert custom_model_dump(getattr(item, key)) == value
@pytest.mark.order(1)
def test_retrieve(handler):
@@ -180,9 +174,9 @@ def create_test_module(
assert custom_model_dump(getattr(item, key)) == value
@pytest.mark.order(3)
def test_modify(handler, caren_agent, name, params, extra_expected_values, expected_error):
"""Test modifying a resource."""
skip_test_if_not_implemented(handler, resource_name, "modify")
def test_update(handler, caren_agent, name, params, extra_expected_values, expected_error):
"""Test updating a resource."""
skip_test_if_not_implemented(handler, resource_name, "update")
if name not in test_item_ids:
pytest.skip(f"Item '{name}' not found in test_items")
@@ -192,7 +186,7 @@ def create_test_module(
processed_extra_expected = preprocess_params(extra_expected_values, caren_agent)
try:
item = handler.modify(**processed_params)
item = handler.update(**processed_params)
except Exception as e:
if expected_error is not None:
assert isinstance(e, expected_error), f"Expected error with type {expected_error}, but got {type(e)}: {e}"
@@ -254,7 +248,7 @@ def create_test_module(
"test_create": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", create_params)(test_create),
"test_retrieve": test_retrieve,
"test_upsert": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", upsert_params)(test_upsert),
"test_modify": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", modify_params)(test_modify),
"test_update": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", update_params)(test_update),
"test_delete": test_delete,
"test_list": pytest.mark.parametrize("query_params, count", list_params)(test_list),
}

View File

@@ -10,7 +10,7 @@ GROUPS_CREATE_PARAMS = [
),
]
GROUPS_MODIFY_PARAMS = [
GROUPS_UPDATE_PARAMS = [
(
"round_robin_group",
{"manager_config": {"manager_type": "round_robin", "max_turns": 10}},
@@ -30,7 +30,7 @@ globals().update(
resource_name="groups",
id_param_name="group_id",
create_params=GROUPS_CREATE_PARAMS,
modify_params=GROUPS_MODIFY_PARAMS,
update_params=GROUPS_UPDATE_PARAMS,
list_params=GROUPS_LIST_PARAMS,
)
)

View File

@@ -5,7 +5,7 @@ IDENTITIES_CREATE_PARAMS = [
("caren2", {"identifier_key": "456", "name": "caren", "identity_type": "user"}, {}, None),
]
IDENTITIES_MODIFY_PARAMS = [
IDENTITIES_UPDATE_PARAMS = [
("caren1", {"properties": [{"key": "email", "value": "caren@letta.com", "type": "string"}]}, {}, None),
("caren2", {"properties": [{"key": "email", "value": "caren@gmail.com", "type": "string"}]}, {}, None),
]
@@ -37,7 +37,7 @@ globals().update(
id_param_name="identity_id",
create_params=IDENTITIES_CREATE_PARAMS,
upsert_params=IDENTITIES_UPSERT_PARAMS,
modify_params=IDENTITIES_MODIFY_PARAMS,
update_params=IDENTITIES_UPDATE_PARAMS,
list_params=IDENTITIES_LIST_PARAMS,
)
)

View File

@@ -8,12 +8,14 @@ with Turbopuffer integration, including vector search, FTS, hybrid search, filte
import time
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
import pytest
from letta_client import Letta
from letta_client.types import CreateBlockParam, MessageCreateParam
from letta.config import LettaConfig
from letta.server.rest_api.routers.v1.passages import PassageSearchResult
from letta.server.server import SyncServer
from letta.settings import settings
@@ -147,7 +149,15 @@ def test_passage_search_basic(client: Letta, enable_turbopuffer):
time.sleep(2)
# Test search by agent_id
results = client.passages.search(query="python programming", agent_id=agent.id, limit=10)
results = client.post(
"/v1/passages/search",
cast_to=list[PassageSearchResult],
body={
"query": "python programming",
"agent_id": agent.id,
"limit": 10,
},
)
assert len(results) > 0, "Should find at least one passage"
assert any("Python" in result.passage.text for result in results), "Should find Python-related passage"
@@ -160,7 +170,15 @@ def test_passage_search_basic(client: Letta, enable_turbopuffer):
assert isinstance(result.score, float), "Score should be a float"
# Test search by archive_id
archive_results = client.passages.search(query="vector database", archive_id=archive.id, limit=10)
archive_results = client.post(
"/v1/passages/search",
cast_to=list[PassageSearchResult],
body={
"query": "vector database",
"archive_id": archive.id,
"limit": 10,
},
)
assert len(archive_results) > 0, "Should find passages in archive"
assert any("Turbopuffer" in result.passage.text or "vector" in result.passage.text for result in archive_results), (
@@ -213,7 +231,15 @@ def test_passage_search_with_tags(client: Letta, enable_turbopuffer):
time.sleep(2)
# Test basic search without tags first
results = client.passages.search(query="programming tutorial", agent_id=agent.id, limit=10)
results = client.post(
"/v1/passages/search",
cast_to=list[PassageSearchResult],
body={
"query": "programming tutorial",
"agent_id": agent.id,
"limit": 10,
},
)
assert len(results) > 0, "Should find passages"
@@ -267,7 +293,16 @@ def test_passage_search_with_date_filters(client: Letta, enable_turbopuffer):
now = datetime.now(timezone.utc)
start_date = now - timedelta(hours=1)
results = client.passages.search(query="AI machine learning", agent_id=agent.id, limit=10, start_date=start_date)
results = client.post(
"/v1/passages/search",
cast_to=list[PassageSearchResult],
body={
"query": "AI machine learning",
"agent_id": agent.id,
"limit": 10,
"start_date": start_date.isoformat(),
},
)
assert len(results) > 0, "Should find recent passages"
@@ -353,15 +388,39 @@ def test_passage_search_pagination(client: Letta, enable_turbopuffer):
time.sleep(2)
# Test with different limit values
results_limit_3 = client.passages.search(query="programming", agent_id=agent.id, limit=3)
results_limit_3 = client.post(
"/v1/passages/search",
cast_to=list[PassageSearchResult],
body={
"query": "programming",
"agent_id": agent.id,
"limit": 3,
},
)
assert len(results_limit_3) == 3, "Should respect limit parameter"
results_limit_5 = client.passages.search(query="programming", agent_id=agent.id, limit=5)
results_limit_5 = client.post(
"/v1/passages/search",
cast_to=list[PassageSearchResult],
body={
"query": "programming",
"agent_id": agent.id,
"limit": 5,
},
)
assert len(results_limit_5) == 5, "Should return 5 results"
results_all = client.passages.search(query="programming", agent_id=agent.id, limit=20)
results_all = client.post(
"/v1/passages/search",
cast_to=list[PassageSearchResult],
body={
"query": "programming",
"agent_id": agent.id,
"limit": 20,
},
)
assert len(results_all) >= 10, "Should return all matching passages"
@@ -414,7 +473,14 @@ def test_passage_search_org_wide(client: Letta, enable_turbopuffer):
time.sleep(2)
# Test org-wide search (no agent_id or archive_id)
results = client.passages.search(query="unique passage", limit=20)
results = client.post(
"/v1/passages/search",
cast_to=list[PassageSearchResult],
body={
"query": "unique passage",
"limit": 20,
},
)
# Should find passages from both agents
assert len(results) >= 2, "Should find passages from multiple agents"

View File

@@ -44,14 +44,14 @@ TOOLS_UPSERT_PARAMS = [
("unfriendly_func", {"source_code": UNFRIENDLY_FUNC_SOURCE_CODE_V2}, {}, None),
]
TOOLS_MODIFY_PARAMS = [
TOOLS_UPDATE_PARAMS = [
("friendly_func", {"tags": ["sdk_test"]}, {}, None),
("unfriendly_func", {"return_char_limit": 300}, {}, None),
]
TOOLS_LIST_PARAMS = [
({}, 2),
({"name": ["friendly_func"]}, 1),
({"name": "friendly_func"}, 1),
]
# Create all test module components at once
@@ -61,7 +61,7 @@ globals().update(
id_param_name="tool_id",
create_params=TOOLS_CREATE_PARAMS,
upsert_params=TOOLS_UPSERT_PARAMS,
modify_params=TOOLS_MODIFY_PARAMS,
update_params=TOOLS_UPDATE_PARAMS,
list_params=TOOLS_LIST_PARAMS,
)
)

View File

@@ -1,56 +0,0 @@
from conftest import create_test_module
# Parameter tables consumed by conftest.create_test_module.
# Each tuple: (logical item name, request kwargs, extra expected response values, expected error or None).
AGENTS_CREATE_PARAMS = [
    (
        "caren_agent",
        {"name": "caren", "model": "openai/gpt-4o-mini", "embedding": "openai/text-embedding-3-small"},
        {
            # Verify model_settings is populated with config values
            # Note: The 'model' field itself is separate from model_settings
            "model_settings": {
                "max_output_tokens": 4096,
                "parallel_tool_calls": False,
                "provider_type": "openai",
                "temperature": 0.7,
                "reasoning": {"reasoning_effort": "minimal"},
                "response_format": None,
            }
        },
        None,
    ),
]
AGENTS_UPDATE_PARAMS = [
    (
        "caren_agent",
        {"name": "caren_updated"},
        {
            # After updating just the name, model_settings should still be present
            "model_settings": {
                "max_output_tokens": 4096,
                "parallel_tool_calls": False,
                "provider_type": "openai",
                "temperature": 0.7,
                "reasoning": {"reasoning_effort": "minimal"},
                "response_format": None,
            }
        },
        None,
    ),
]
# Each tuple: (list-endpoint query params, expected number of module-created matches).
AGENTS_LIST_PARAMS = [
    ({}, 1),
    ({"name": "caren_updated"}, 1),
]
# Create all test module components at once
globals().update(
    create_test_module(
        resource_name="agents",
        id_param_name="agent_id",
        create_params=AGENTS_CREATE_PARAMS,
        update_params=AGENTS_UPDATE_PARAMS,
        list_params=AGENTS_LIST_PARAMS,
    )
)

View File

@@ -1,31 +0,0 @@
from conftest import create_test_module
from letta_client import UnprocessableEntityError
from letta.constants import CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT
# Parameter tables consumed by conftest.create_test_module.
# Each tuple: (logical item name, request kwargs, extra expected response values, expected error or None).
BLOCKS_CREATE_PARAMS = [
    ("human_block", {"label": "human", "value": "test"}, {"limit": CORE_MEMORY_HUMAN_CHAR_LIMIT}, None),
    ("persona_block", {"label": "persona", "value": "test1"}, {"limit": CORE_MEMORY_PERSONA_CHAR_LIMIT}, None),
]
BLOCKS_UPDATE_PARAMS = [
    ("human_block", {"value": "test2"}, {}, None),
    # Value longer than the shrunken limit must be rejected by the server (422)
    ("persona_block", {"value": "testing testing testing", "limit": 10}, {}, UnprocessableEntityError),
]
# Each tuple: (list-endpoint query params, expected number of module-created matches).
BLOCKS_LIST_PARAMS = [
    ({}, 2),
    ({"label": "human"}, 1),
    ({"label": "persona"}, 1),
]
# Create all test module components at once
globals().update(
    create_test_module(
        resource_name="blocks",
        id_param_name="block_id",
        create_params=BLOCKS_CREATE_PARAMS,
        update_params=BLOCKS_UPDATE_PARAMS,
        list_params=BLOCKS_LIST_PARAMS,
    )
)

View File

@@ -1,319 +0,0 @@
import os
import threading
import time
from typing import Any, Dict, List, Optional, Tuple
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
@pytest.fixture(scope="session")
def server_url() -> str:
    """
    Provides the URL for the Letta server.
    If LETTA_SERVER_URL is not set, starts the server in a background thread
    and polls until it's accepting connections.

    Returns:
        The base URL the tests should target.
    """
    def _run_server() -> None:
        # Runs in a daemon thread; start_server blocks for the thread's lifetime.
        load_dotenv()
        from letta.server.rest_api.app import start_server
        start_server(debug=True)
    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        # No external server configured: launch one in-process.
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()
        # Poll until the server is up (or timeout)
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        while time.time() < deadline:
            try:
                resp = requests.get(url + "/v1/health")
                # Any non-5xx response means the server is accepting requests
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                # Server not listening yet — keep polling
                pass
            time.sleep(0.1)
        else:
            # while/else: runs only if the deadline expired without a break
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
    return url
# This fixture creates a client for each test module
@pytest.fixture(scope="session")
def client(server_url: str) -> Letta:
    """Yield a synchronous Letta REST client pointed at the test server."""
    yield Letta(base_url=server_url)
def skip_test_if_not_implemented(handler, resource_name, test_name):
    """Skip the current test when *handler* does not expose *test_name*."""
    if hasattr(handler, test_name):
        return
    pytest.skip(f"client.{resource_name}.{test_name} not implemented")
def create_test_module(
    resource_name: str,
    id_param_name: str,
    create_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
    upsert_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
    update_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
    list_params: List[Tuple[Dict[str, Any], int]] = [],
) -> Dict[str, Any]:
    """Create a test module for a resource.
    This function creates all the necessary test methods and returns them in a dictionary
    that can be added to the globals() of the module.
    Args:
        resource_name: Name of the resource (e.g., "blocks", "tools")
        id_param_name: Name of the ID parameter (e.g., "block_id", "tool_id")
        create_params: List of (name, params, expected_error) tuples for create tests
        update_params: List of (name, params, expected_error) tuples for update tests
        list_params: List of (query_params, expected_count) tuples for list tests
    Returns:
        Dict: A dictionary of all test functions that should be added to the module globals
    """
    # NOTE(review): the mutable default arguments ([]) are shared across calls,
    # but they are only read here, never mutated, so this is harmless in practice.
    # Create shared test state
    # Maps logical item name -> server-assigned ID; shared by all closures below,
    # which is why the tests are strictly ordered via pytest.mark.order.
    test_item_ids = {}
    # Create fixture functions
    @pytest.fixture(scope="session")
    def handler(client):
        # The resource-specific sub-client, e.g. client.blocks or client.tools
        return getattr(client, resource_name)
    @pytest.fixture(scope="session")
    def caren_agent(client, request):
        """Create an agent to be used as manager in supervisor groups."""
        agent = client.agents.create(
            name="caren_agent",
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
        )
        # Add finalizer to ensure cleanup happens in the right order
        request.addfinalizer(lambda: client.agents.delete(agent_id=agent.id))
        return agent
    # Create standalone test functions
    @pytest.mark.order(0)
    def test_create(handler, caren_agent, name, params, extra_expected_values, expected_error):
        """Test creating a resource."""
        skip_test_if_not_implemented(handler, resource_name, "create")
        # Use preprocess_params which adds fixtures
        processed_params = preprocess_params(params, caren_agent)
        processed_extra_expected = preprocess_params(extra_expected_values, caren_agent)
        try:
            item = handler.create(**processed_params)
        except Exception as e:
            if expected_error is not None:
                # SDK exceptions may expose either .status_code or .status
                if hasattr(e, "status_code"):
                    assert e.status_code == expected_error
                elif hasattr(e, "status"):
                    assert e.status == expected_error
                else:
                    pytest.fail(f"Expected error with status {expected_error}, but got {type(e)}: {e}")
            else:
                raise e
        # Store item ID for later tests
        test_item_ids[name] = item.id
        # Verify item properties
        expected_values = processed_params | processed_extra_expected
        for key, value in expected_values.items():
            if hasattr(item, key):
                assert custom_model_dump(getattr(item, key)) == value
    @pytest.mark.order(1)
    def test_retrieve(handler):
        """Test retrieving resources."""
        skip_test_if_not_implemented(handler, resource_name, "retrieve")
        for name, item_id in test_item_ids.items():
            kwargs = {id_param_name: item_id}
            item = handler.retrieve(**kwargs)
            assert hasattr(item, "id") and item.id == item_id, f"{resource_name.capitalize()} {name} with id {item_id} not found"
    @pytest.mark.order(2)
    def test_upsert(handler, name, params, extra_expected_values, expected_error):
        """Test upserting resources."""
        skip_test_if_not_implemented(handler, resource_name, "upsert")
        existing_item_id = test_item_ids[name]
        try:
            item = handler.upsert(**params)
        except Exception as e:
            if expected_error is not None:
                if hasattr(e, "status_code"):
                    assert e.status_code == expected_error
                elif hasattr(e, "status"):
                    assert e.status == expected_error
                else:
                    pytest.fail(f"Expected error with status {expected_error}, but got {type(e)}: {e}")
            else:
                raise e
        # Upsert on an existing key must update in place, not create a new item
        assert existing_item_id == item.id
        # Verify item properties
        expected_values = params | extra_expected_values
        for key, value in expected_values.items():
            if hasattr(item, key):
                assert custom_model_dump(getattr(item, key)) == value
    @pytest.mark.order(3)
    def test_update(handler, caren_agent, name, params, extra_expected_values, expected_error):
        """Test updating a resource."""
        skip_test_if_not_implemented(handler, resource_name, "update")
        if name not in test_item_ids:
            pytest.skip(f"Item '{name}' not found in test_items")
        kwargs = {id_param_name: test_item_ids[name]}
        kwargs.update(params)
        processed_params = preprocess_params(kwargs, caren_agent)
        processed_extra_expected = preprocess_params(extra_expected_values, caren_agent)
        try:
            item = handler.update(**processed_params)
        except Exception as e:
            if expected_error is not None:
                # Unlike create/upsert, update expectations are exception types
                assert isinstance(e, expected_error), f"Expected error with type {expected_error}, but got {type(e)}: {e}"
                return
            else:
                raise e
        # Verify item properties
        expected_values = processed_params | processed_extra_expected
        for key, value in expected_values.items():
            if hasattr(item, key):
                assert custom_model_dump(getattr(item, key)) == value
        # Verify via retrieve as well
        retrieve_kwargs = {id_param_name: item.id}
        retrieved_item = handler.retrieve(**retrieve_kwargs)
        expected_values = processed_params | processed_extra_expected
        for key, value in expected_values.items():
            if hasattr(retrieved_item, key):
                assert custom_model_dump(getattr(retrieved_item, key)) == value
    @pytest.mark.order(4)
    def test_list(handler, query_params, count):
        """Test listing resources."""
        skip_test_if_not_implemented(handler, resource_name, "list")
        all_items = handler.list(**query_params)
        # Only count items created by this module; the server may hold others
        test_items_list = [item.id for item in all_items if item.id in test_item_ids.values()]
        assert len(test_items_list) == count
    @pytest.mark.order(-1)
    def test_delete(handler):
        """Test deleting resources."""
        skip_test_if_not_implemented(handler, resource_name, "delete")
        for item_id in test_item_ids.values():
            kwargs = {id_param_name: item_id}
            handler.delete(**kwargs)
        # Every deleted item must now fail retrieval with a 404
        for name, item_id in test_item_ids.items():
            try:
                kwargs = {id_param_name: item_id}
                item = handler.retrieve(**kwargs)
                raise AssertionError(f"{resource_name.capitalize()} {name} with id {item.id} was not deleted")
            except Exception as e:
                if isinstance(e, AssertionError):
                    raise e
                if hasattr(e, "status_code"):
                    assert e.status_code == 404, f"Expected 404 error, got {e.status_code}"
                else:
                    raise AssertionError(f"Unexpected error type: {type(e)}")
        test_item_ids.clear()
    # Create test methods dictionary
    result = {
        "handler": handler,
        "caren_agent": caren_agent,
        "test_create": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", create_params)(test_create),
        "test_retrieve": test_retrieve,
        "test_upsert": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", upsert_params)(test_upsert),
        "test_update": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", update_params)(test_update),
        "test_delete": test_delete,
        "test_list": pytest.mark.parametrize("query_params, count", list_params)(test_list),
    }
    return result
def custom_model_dump(model):
    """
    Recursively dump the given model into plain, comparable Python values.

    Scalars pass through unchanged; lists and tuples are converted element-wise
    (tuples are normalized to lists so they compare equal to the list-valued
    params used throughout these tests); dicts are converted value-wise.
    Anything else is assumed to be a pydantic model and dumped via model_dump().

    Args:
        model: The model (or nested structure of models) to dump
    Returns:
        The dumped model
    """
    if isinstance(model, (str, int, float, bool, type(None))):
        return model
    # Fix: tuples previously fell through to model_dump() and raised
    # AttributeError; treat them like lists.
    if isinstance(model, (list, tuple)):
        return [custom_model_dump(item) for item in model]
    if isinstance(model, dict):
        return {key: custom_model_dump(value) for key, value in model.items()}
    # Fall-through: pydantic models expose model_dump().
    return model.model_dump()
def add_fixture_params(value, caren_agent):
    """Resolve known fixture placeholder strings to their live values.

    Args:
        value: The value to process (should be a string)
        caren_agent: The agent object supplying replacement IDs
    Returns:
        The live value when *value* is a recognized placeholder
        (currently only "caren_agent.id"), otherwise *value* unchanged.
    """
    if value == "caren_agent.id":
        return caren_agent.id
    return value
def preprocess_params(params, caren_agent):
    """Recursively walk dicts and lists, resolving '.id' placeholder strings.

    Args:
        params: The parameters to process (dict, list, or scalar value)
        caren_agent: The agent object to use for ID replacement
    Returns:
        A structure of the same shape with placeholder ID strings resolved.
    """
    if isinstance(params, dict):
        return {k: preprocess_params(v, caren_agent) for k, v in params.items()}
    if isinstance(params, list):
        return [preprocess_params(entry, caren_agent) for entry in params]
    if isinstance(params, str) and ".id" in params:
        # Only placeholder-looking strings go through the fixture mapping
        return add_fixture_params(params, caren_agent)
    # Scalars and unrecognized types pass through untouched
    return params

View File

@@ -1,36 +0,0 @@
from conftest import create_test_module
# Parameter tables consumed by conftest.create_test_module.
# Each tuple: (logical item name, request kwargs, extra expected response values, expected error or None).
GROUPS_CREATE_PARAMS = [
    ("round_robin_group", {"agent_ids": [], "description": ""}, {"manager_type": "round_robin"}, None),
    (
        "supervisor_group",
        # "caren_agent.id" is a placeholder resolved by preprocess_params at test time
        {"agent_ids": [], "description": "", "manager_config": {"manager_type": "supervisor", "manager_agent_id": "caren_agent.id"}},
        {"manager_type": "supervisor"},
        None,
    ),
]
GROUPS_UPDATE_PARAMS = [
    (
        "round_robin_group",
        {"manager_config": {"manager_type": "round_robin", "max_turns": 10}},
        {"manager_type": "round_robin", "max_turns": 10},
        None,
    ),
]
# Each tuple: (list-endpoint query params, expected number of module-created matches).
GROUPS_LIST_PARAMS = [
    ({}, 2),
    ({"manager_type": "round_robin"}, 1),
]
# Create all test module components at once
globals().update(
    create_test_module(
        resource_name="groups",
        id_param_name="group_id",
        create_params=GROUPS_CREATE_PARAMS,
        update_params=GROUPS_UPDATE_PARAMS,
        list_params=GROUPS_LIST_PARAMS,
    )
)

View File

@@ -1,43 +0,0 @@
from conftest import create_test_module
# Parameter tables consumed by conftest.create_test_module.
# Each tuple: (logical item name, request kwargs, extra expected response values, expected error or None).
IDENTITIES_CREATE_PARAMS = [
    ("caren1", {"identifier_key": "123", "name": "caren", "identity_type": "user"}, {}, None),
    ("caren2", {"identifier_key": "456", "name": "caren", "identity_type": "user"}, {}, None),
]
IDENTITIES_UPDATE_PARAMS = [
    ("caren1", {"properties": [{"key": "email", "value": "caren@letta.com", "type": "string"}]}, {}, None),
    ("caren2", {"properties": [{"key": "email", "value": "caren@gmail.com", "type": "string"}]}, {}, None),
]
# Upsert reuses identifier_key "456" so it must update caren2 rather than create
IDENTITIES_UPSERT_PARAMS = [
    (
        "caren2",
        {
            "identifier_key": "456",
            "name": "caren",
            "identity_type": "user",
            "properties": [{"key": "email", "value": "caren@yahoo.com", "type": "string"}],
        },
        {},
        None,
    ),
]
# Each tuple: (list-endpoint query params, expected number of module-created matches).
IDENTITIES_LIST_PARAMS = [
    ({}, 2),
    ({"name": "caren"}, 2),
    ({"identifier_key": "123"}, 1),
]
# Create all test module components at once
globals().update(
    create_test_module(
        resource_name="identities",
        id_param_name="identity_id",
        create_params=IDENTITIES_CREATE_PARAMS,
        upsert_params=IDENTITIES_UPSERT_PARAMS,
        update_params=IDENTITIES_UPDATE_PARAMS,
        list_params=IDENTITIES_LIST_PARAMS,
    )
)

View File

@@ -1,313 +0,0 @@
import json
import os
import threading
import time
import uuid
from unittest.mock import MagicMock, patch
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.
    If LETTA_SERVER_URL is not set, starts the server in a background thread
    and polls until it's accepting connections.

    Yields:
        The base URL the tests should target.
    """
    def _run_server() -> None:
        # Runs in a daemon thread; start_server blocks for the thread's lifetime.
        load_dotenv()
        from letta.server.rest_api.app import start_server
        start_server(debug=True)
    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        # No external server configured: launch one in-process.
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()
        # Poll until the server is up (or timeout)
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        while time.time() < deadline:
            try:
                resp = requests.get(url + "/v1/health")
                # Any non-5xx response means the server is accepting requests
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                # Server not listening yet — keep polling
                pass
            time.sleep(0.1)
        else:
            # while/else: runs only if the deadline expired without a break
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
    yield url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """Yield a synchronous Letta REST client pointed at the test server."""
    yield Letta(base_url=server_url)
@pytest.fixture(scope="function")
def agent_state(client: Letta) -> AgentState:
    """
    Creates and returns an agent state for testing with a pre-configured agent.
    Uses system-level EXA_API_KEY setting.
    """
    # Ensure the base tool set exists before looking tools up by name
    client.tools.upsert_base_tools()
    send_message_tool = list(client.tools.list(name="send_message"))[0]
    run_code_tool = list(client.tools.list(name="run_code"))[0]
    web_search_tool = list(client.tools.list(name="web_search"))[0]
    # Fresh agent per test (function scope) equipped with only the three tools above
    agent_state_instance = client.agents.create(
        name="test_builtin_tools_agent",
        include_base_tools=False,
        tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
        model="openai/gpt-4o",
        embedding="letta/letta-free",
        tags=["test_builtin_tools_agent"],
    )
    yield agent_state_instance
# ------------------------------
# Helper Functions and Constants
# ------------------------------
USER_MESSAGE_OTID = str(uuid.uuid4())
TEST_LANGUAGES = ["Python", "Javascript", "Typescript"]
EXPECTED_INTEGER_PARTITION_OUTPUT = "190569292"
# Reference implementation in Python, to embed in the user prompt
REFERENCE_CODE = """\
def reference_partition(n):
partitions = [1] + [0] * (n + 1)
for k in range(1, n + 1):
for i in range(k, n + 1):
partitions[i] += partitions[i - k]
return partitions[n]
"""
def reference_partition(n: int) -> int:
    """Compute the integer partition count p(n) with the classic DP.

    Mirrors the logic embedded in REFERENCE_CODE; used to derive the
    expected result checked against the agent's tool output.
    """
    counts = [0] * (n + 2)
    counts[0] = 1  # the empty partition
    for part in range(1, n + 1):
        # Fold part-sized pieces into every reachable total
        for total in range(part, n + 1):
            counts[total] += counts[total - part]
    return counts[n]
# ------------------------------
# Test Cases
# ------------------------------
@pytest.mark.parametrize("language", TEST_LANGUAGES, ids=TEST_LANGUAGES)
def test_run_code(
client: Letta,
agent_state: AgentState,
language: str,
) -> None:
"""
Sends a reference Python implementation, asks the model to translate & run it
in different languages, and verifies the exact partition(100) result.
"""
expected = str(reference_partition(100))
user_message = MessageCreateParam(
role="user",
content=(
"Here is a Python reference implementation:\n\n"
f"{REFERENCE_CODE}\n"
f"Please translate and execute this code in {language} to compute p(100), "
"and return **only** the result with no extra formatting."
),
otid=USER_MESSAGE_OTID,
)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[user_message],
)
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
assert tool_returns, f"No ToolReturnMessage found for language: {language}"
returns = [m.tool_return for m in tool_returns]
assert any(expected in ret for ret in returns), (
f"For language={language!r}, expected to find '{expected}' in tool_return, but got {returns!r}"
)
@pytest.mark.asyncio(scope="function")
async def test_web_search() -> None:
    """Test web search tool with mocked Exa API.

    Patches the Exa client so no network traffic occurs, invokes the executor's
    web_search directly, and checks both the JSON response shape and the
    arguments forwarded to Exa.
    """
    # create mock agent state with exa api key
    mock_agent_state = MagicMock()
    mock_agent_state.get_agent_env_vars_as_dict.return_value = {"EXA_API_KEY": "test-exa-key"}
    # Mock Exa search result with education information
    mock_exa_result = MagicMock()
    mock_exa_result.results = [
        MagicMock(
            title="Charles Packer - UC Berkeley PhD in Computer Science",
            url="https://example.com/charles-packer-profile",
            published_date="2023-01-01",
            author="UC Berkeley",
            text=None,
            highlights=["Charles Packer completed his PhD at UC Berkeley", "Research in artificial intelligence and machine learning"],
            summary="Charles Packer is the CEO of Letta who earned his PhD in Computer Science from UC Berkeley, specializing in AI research.",
        ),
        MagicMock(
            title="Letta Leadership Team",
            url="https://letta.com/team",
            published_date="2023-06-01",
            author="Letta",
            text=None,
            highlights=["CEO Charles Packer brings academic expertise"],
            summary="Leadership team page featuring CEO Charles Packer's educational background.",
        ),
    ]
    with patch("exa_py.Exa") as mock_exa_class:
        # Setup mock
        mock_exa_client = MagicMock()
        mock_exa_class.return_value = mock_exa_client
        mock_exa_client.search_and_contents.return_value = mock_exa_result
        # create executor with mock dependencies (managers are unused by web_search)
        executor = LettaBuiltinToolExecutor(
            message_manager=MagicMock(),
            agent_manager=MagicMock(),
            block_manager=MagicMock(),
            run_manager=MagicMock(),
            passage_manager=MagicMock(),
            actor=MagicMock(),
        )
        # call web_search directly
        result = await executor.web_search(
            agent_state=mock_agent_state,
            query="where did Charles Packer, CEO of Letta, go to school",
            num_results=10,
            include_text=False,
        )
        # Parse the JSON response from web_search
        response_json = json.loads(result)
        # Basic structure assertions for new Exa format
        assert "query" in response_json, "Missing 'query' field in response"
        assert "results" in response_json, "Missing 'results' field in response"
        # Verify we got search results
        results = response_json["results"]
        assert len(results) == 2, "Should have found exactly 2 search results from mock"
        # Check each result has the expected structure
        found_education_info = False
        # NOTE(review): this loop rebinds `result`, shadowing the raw JSON string above.
        for result in results:
            assert "title" in result, "Result missing title"
            assert "url" in result, "Result missing URL"
            # text should not be present since include_text=False by default
            assert "text" not in result or result["text"] is None, "Text should not be included by default"
            # Check for education-related information in summary and highlights
            result_text = ""
            if "summary" in result and result["summary"]:
                result_text += " " + result["summary"].lower()
            if "highlights" in result and result["highlights"]:
                for highlight in result["highlights"]:
                    result_text += " " + highlight.lower()
            # Look for education keywords
            if any(keyword in result_text for keyword in ["berkeley", "university", "phd", "ph.d", "education", "student"]):
                found_education_info = True
        assert found_education_info, "Should have found education-related information about Charles Packer"
        # Verify Exa was called with correct parameters
        mock_exa_class.assert_called_once_with(api_key="test-exa-key")
        mock_exa_client.search_and_contents.assert_called_once()
        call_args = mock_exa_client.search_and_contents.call_args
        assert call_args[1]["type"] == "auto"
        assert call_args[1]["text"] is False  # Default is False now
@pytest.mark.asyncio(scope="function")
async def test_web_search_uses_exa():
    """Test that web search uses Exa API correctly."""
    # Agent-state stub exposing only the env-var lookup web_search performs.
    fake_agent_state = MagicMock()
    fake_agent_state.get_agent_env_vars_as_dict.return_value = {"EXA_API_KEY": "test-exa-key"}
    # One canned Exa hit returned by the patched client.
    fake_exa_result = MagicMock()
    fake_exa_result.results = [
        MagicMock(
            title="Test Result",
            url="https://example.com/test",
            published_date="2023-01-01",
            author="Test Author",
            text="This is test content from the search result.",
            highlights=["This is a highlight"],
            summary="This is a summary of the content.",
        )
    ]
    with patch("exa_py.Exa") as patched_exa:
        exa_instance = MagicMock()
        patched_exa.return_value = exa_instance
        exa_instance.search_and_contents.return_value = fake_exa_result
        # Executor wired with throwaway manager mocks (unused by web_search).
        executor = LettaBuiltinToolExecutor(
            message_manager=MagicMock(),
            agent_manager=MagicMock(),
            block_manager=MagicMock(),
            run_manager=MagicMock(),
            passage_manager=MagicMock(),
            actor=MagicMock(),
        )
        raw_response = await executor.web_search(
            agent_state=fake_agent_state, query="test query", num_results=3, include_text=True
        )
        # The Exa client must be constructed with the agent's key and queried once.
        patched_exa.assert_called_once_with(api_key="test-exa-key")
        exa_instance.search_and_contents.assert_called_once()
        kwargs = exa_instance.search_and_contents.call_args[1]
        assert kwargs["query"] == "test query"
        assert kwargs["num_results"] == 3
        assert kwargs["type"] == "auto"
        assert kwargs["text"] is True
        # Response is JSON echoing the query with the single mocked result.
        parsed = json.loads(raw_response)
        assert "query" in parsed
        assert "results" in parsed
        assert parsed["query"] == "test query"
        assert len(parsed["results"]) == 1

File diff suppressed because it is too large Load Diff

View File

@@ -1,375 +0,0 @@
import ast
import json
import os
import threading
import time
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta_client.types.agents import SystemMessage
from tests.helpers.utils import retry_until_success
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.
    If LETTA_SERVER_URL is not set, starts the server in a background thread
    and polls until it's accepting connections.
    """

    def _serve_in_background() -> None:
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        threading.Thread(target=_serve_in_background, daemon=True).start()
        # Poll the health endpoint every 100ms until it answers or we give up.
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        server_ready = False
        while time.time() < deadline:
            try:
                # Any non-5xx status means the HTTP layer is up.
                if requests.get(url + "/v1/health").status_code < 500:
                    server_ready = True
                    break
            except requests.exceptions.RequestException:
                pass
            time.sleep(0.1)
        if not server_ready:
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
    yield url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """
    Creates and returns a synchronous Letta REST client for testing.
    """
    sdk_client = Letta(base_url=server_url)
    # Ensure the server-side base tool set exists before any test runs.
    sdk_client.tools.upsert_base_tools()
    yield sdk_client
@pytest.fixture(autouse=True)
def remove_stale_agents(client):
    """Delete agents left behind by earlier runs so each test starts clean."""
    for stale_agent in client.agents.list(limit=300):
        client.agents.delete(agent_id=stale_agent.id)
@pytest.fixture(scope="function")
def agent_obj(client: Letta) -> AgentState:
    """Create a test agent that we can call functions on"""
    # Attach the cross-agent messaging tool so this agent can contact peers.
    messaging_tool = list(client.tools.list(name="send_message_to_agent_and_wait_for_reply"))[0]
    created_agent = client.agents.create(
        agent_type="letta_v1_agent",
        include_base_tools=True,
        tool_ids=[messaging_tool.id],
        model="openai/gpt-4o",
        embedding="letta/letta-free",
        context_window_limit=32000,
    )
    yield created_agent
@pytest.fixture(scope="function")
def other_agent_obj(client: Letta) -> AgentState:
    """Create another test agent that we can call functions on"""
    # This peer deliberately lacks multi-agent tools; it only receives messages.
    peer_agent = client.agents.create(
        agent_type="letta_v1_agent",
        include_base_tools=True,
        include_multi_agent_tools=False,
        model="openai/gpt-4o",
        embedding="letta/letta-free",
        context_window_limit=32000,
    )
    yield peer_agent
@pytest.fixture
def roll_dice_tool(client: Letta):
    """Register a deterministic dice-rolling tool and yield its record."""

    # NOTE: the inner function's name, signature, and docstring are consumed by
    # upsert_from_function to build the tool schema — keep them exactly as-is.
    def roll_dice():
        """
        Rolls a 6 sided die.
        Returns:
            str: The roll result.
        """
        return "Rolled a 5!"

    # Use SDK method to create tool from function
    tool = client.tools.upsert_from_function(func=roll_dice)
    # Yield the created tool
    yield tool
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agent(client: Letta, agent_obj: AgentState, other_agent_obj: AgentState):
    """Agent A relays a secret word to agent B via its messaging tool; verify B
    received it and A's context recorded B's reply."""
    secret_word = "banana"
    # Encourage the agent to send a message to the other agent_obj with the secret string
    response = client.agents.messages.create(
        agent_id=agent_obj.id,
        messages=[
            MessageCreateParam(
                role="user",
                content=f"Use your tool to send a message to another agent with id {other_agent_obj.id} to share the secret word: {secret_word}!",
            )
        ],
    )
    # Get messages from the other agent
    messages_page = client.agents.messages.list(agent_id=other_agent_obj.id)
    messages = messages_page.items
    # Check for the presence of system message with secret word
    # (cross-agent deliveries arrive as SystemMessage; scan newest-first)
    found_secret = False
    for m in reversed(messages):
        print(f"\n\n {other_agent_obj.id} -> {m.model_dump_json(indent=4)}")
        if isinstance(m, SystemMessage):
            if secret_word in m.content:
                found_secret = True
                break
    assert found_secret, f"Secret word '{secret_word}' not found in system messages of agent {other_agent_obj.id}"
    # Search the sender agent for the response from another agent
    in_context_messages_page = client.agents.messages.list(agent_id=agent_obj.id)
    in_context_messages = in_context_messages_page.items
    found = False
    # The tool return embeds the peer's reply as a dict-repr snippet.
    target_snippet = f"'agent_id': '{other_agent_obj.id}', 'response': ["
    for m in in_context_messages:
        # Check ToolReturnMessage for the response
        if isinstance(m, ToolReturnMessage):
            if target_snippet in m.tool_return:
                found = True
                break
        # Handle different message content structures
        elif hasattr(m, "content"):
            if isinstance(m.content, list) and len(m.content) > 0:
                content_text = m.content[0].text if hasattr(m.content[0], "text") else str(m.content[0])
            else:
                content_text = str(m.content)
            if target_snippet in content_text:
                found = True
                break
    if not found:
        # Print debug info (skip the first message, which is the system prompt)
        joined = "\n".join(
            [
                str(
                    m.content[0].text
                    if hasattr(m, "content") and isinstance(m.content, list) and len(m.content) > 0 and hasattr(m.content[0], "text")
                    else m.content
                    if hasattr(m, "content")
                    else f"<{type(m).__name__}>"
                )
                for m in in_context_messages[1:]
            ]
        )
        print(f"In context messages of the sender agent (without system):\n\n{joined}")
        raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}")
    # Test that the agent can still receive messages fine
    response = client.agents.messages.create(
        agent_id=agent_obj.id,
        messages=[
            MessageCreateParam(
                role="user",
                content="So what did the other agent say?",
            )
        ],
    )
    print(response.messages)
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_simple(client: Letta):
    """A manager agent broadcasts a secret word to agents matching one tag set;
    verify only the matching workers received it."""
    worker_tags_123 = ["worker", "user-123"]
    worker_tags_456 = ["worker", "user-456"]
    secret_word = "banana"
    # Create "manager" agent
    send_message_to_agents_matching_tags_tool_id = list(client.tools.list(name="send_message_to_agents_matching_tags"))[0].id
    manager_agent_state = client.agents.create(
        agent_type="letta_v1_agent",
        name="manager_agent",
        tool_ids=[send_message_to_agents_matching_tags_tool_id],
        model="openai/gpt-4o-mini",
        embedding="letta/letta-free",
    )
    # Create 2 non-matching worker agents (These should NOT get the message)
    worker_agents_123 = []
    for idx in range(2):
        worker_agent_state = client.agents.create(
            agent_type="letta_v1_agent",
            name=f"not_worker_{idx}",
            include_multi_agent_tools=False,
            tags=worker_tags_123,
            model="openai/gpt-4o-mini",
            embedding="letta/letta-free",
        )
        worker_agents_123.append(worker_agent_state)
    # Create 2 worker agents that should get the message
    worker_agents_456 = []
    for idx in range(2):
        worker_agent_state = client.agents.create(
            agent_type="letta_v1_agent",
            name=f"worker_{idx}",
            include_multi_agent_tools=False,
            tags=worker_tags_456,
            model="openai/gpt-4o-mini",
            embedding="letta/letta-free",
        )
        worker_agents_456.append(worker_agent_state)
    # Encourage the manager to send a message to the other agent_obj with the secret string
    response = client.agents.messages.create(
        agent_id=manager_agent_state.id,
        messages=[
            MessageCreateParam(
                role="user",
                content=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
            )
        ],
    )
    # The broadcast tool returns a str(list) of JSON strings, one per recipient.
    for m in response.messages:
        if isinstance(m, ToolReturnMessage):
            tool_response = ast.literal_eval(m.tool_return)
            print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
            assert len(tool_response) == len(worker_agents_456)
            # Verify responses from all expected worker agents
            worker_agent_ids = {agent.id for agent in worker_agents_456}
            returned_agent_ids = set()
            for json_str in tool_response:
                response_obj = json.loads(json_str)
                assert response_obj["agent_id"] in worker_agent_ids
                assert response_obj["response_messages"] != ["<no response>"]
                returned_agent_ids.add(response_obj["agent_id"])
            break
    # Check messages in the worker agents that should have received the message
    for agent_state in worker_agents_456:
        messages_page = client.agents.messages.list(agent_state.id)
        messages = messages_page.items
        # Check for the presence of system message
        found_secret = False
        for m in reversed(messages):
            print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
            if isinstance(m, SystemMessage):
                if secret_word in m.content:
                    found_secret = True
                    break
        assert found_secret, f"Secret word not found in messages for agent {agent_state.id}"
    # Ensure it's NOT in the non matching worker agents
    for agent_state in worker_agents_123:
        messages_page = client.agents.messages.list(agent_state.id)
        messages = messages_page.items
        # Check for the presence of system message
        for m in reversed(messages):
            print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
            if isinstance(m, SystemMessage):
                assert secret_word not in m.content, f"Secret word should not be in agent {agent_state.id}"
    # Test that the agent can still receive messages fine
    response = client.agents.messages.create(
        agent_id=manager_agent_state.id,
        messages=[
            MessageCreateParam(
                role="user",
                content="So what did the other agent say?",
            )
        ],
    )
    print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_complex_tool_use(client: Letta, roll_dice_tool):
    """A manager broadcasts to tagged workers that each have a custom dice tool;
    verify every worker answered (exercising nested tool use)."""
    # Create "manager" agent
    send_message_to_agents_matching_tags_tool_id = list(client.tools.list(name="send_message_to_agents_matching_tags"))[0].id
    manager_agent_state = client.agents.create(
        agent_type="letta_v1_agent",
        tool_ids=[send_message_to_agents_matching_tags_tool_id],
        model="openai/gpt-4o-mini",
        embedding="letta/letta-free",
    )
    # Create 2 worker agents
    worker_agents = []
    worker_tags = ["dice-rollers"]
    for _ in range(2):
        worker_agent_state = client.agents.create(
            agent_type="letta_v1_agent",
            include_multi_agent_tools=False,
            tags=worker_tags,
            tool_ids=[roll_dice_tool.id],
            model="openai/gpt-4o-mini",
            embedding="letta/letta-free",
        )
        worker_agents.append(worker_agent_state)
    # Encourage the manager to send a message to the other agent_obj with the secret string
    broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
    response = client.agents.messages.create(
        agent_id=manager_agent_state.id,
        messages=[
            MessageCreateParam(
                role="user",
                content=broadcast_message,
            )
        ],
    )
    for m in response.messages:
        if isinstance(m, ToolReturnMessage):
            # Parse tool_return string to get list of responses
            tool_response = ast.literal_eval(m.tool_return)
            print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
            assert len(tool_response) == len(worker_agents)
            # Verify responses from all expected worker agents
            worker_agent_ids = {agent.id for agent in worker_agents}
            returned_agent_ids = set()
            all_responses = []
            for json_str in tool_response:
                response_obj = json.loads(json_str)
                assert response_obj["agent_id"] in worker_agent_ids
                assert response_obj["response_messages"] != ["<no response>"]
                returned_agent_ids.add(response_obj["agent_id"])
                all_responses.extend(response_obj["response_messages"])
            break
    # Test that the agent can still receive messages fine
    response = client.agents.messages.create(
        agent_id=manager_agent_state.id,
        messages=[
            MessageCreateParam(
                role="user",
                content="So what did the other agent say?",
            )
        ],
    )
    print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))

File diff suppressed because it is too large Load Diff

View File

@@ -1,898 +0,0 @@
import asyncio
import itertools
import json
import logging
import os
import threading
import time
import uuid
from typing import Any, List, Tuple
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
from letta_client.types.agents import AssistantMessage, ReasoningMessage, Run, ToolCallMessage, UserMessage
from letta_client.types.agents.letta_streaming_response import LettaPing, LettaStopReason, LettaUsageStatistics
logger = logging.getLogger(__name__)
# ------------------------------
# Helper Functions and Constants
# ------------------------------
# Model-settings config files (under tests/sdk_v1/model_settings) exercised by
# default; LLM_CONFIG_FILE can narrow the run to a single config (see below).
all_configs = [
    "openai-gpt-4o-mini.json",
    "openai-gpt-4.1.json",
    "openai-gpt-5.json",
    "claude-4-5-sonnet.json",
    "gemini-2.5-pro.json",
]
def get_model_config(filename: str, model_settings_dir: str = "tests/sdk_v1/model_settings") -> Tuple[str, dict]:
    """Load a model_settings file and return the handle and settings dict."""
    config_path = os.path.join(model_settings_dir, filename)
    with open(config_path, "r") as config_file:
        parsed = json.load(config_file)
    # "model_settings" is optional in the config files; default to no overrides.
    return parsed["handle"], parsed.get("model_settings", {})
# Allow CI to pin a single config via LLM_CONFIG_FILE; otherwise test them all.
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
# (handle, model_settings) pairs consumed by the parametrized tests below.
TESTED_MODEL_CONFIGS: List[Tuple[str, dict]] = [get_model_config(fn) for fn in filenames]
def roll_dice(num_sides: int) -> int:
    """
    Returns a random number between 1 and num_sides.
    Args:
        num_sides (int): The number of sides on the die.
    Returns:
        int: A random integer between 1 and num_sides, representing the die roll.
    """
    # Import kept inside the function body: the source is shipped to the tool
    # sandbox, so it must be self-contained. Docstring is the tool schema.
    from random import randint

    return randint(1, num_sides)
# Shared OTID stamped on every scripted user message so DB round-trip
# assertions can find the persisted message again.
USER_MESSAGE_OTID = str(uuid.uuid4())
# Exact phrase the model is told to echo back in the greeting tests.
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
# Canned prompt: force a plain assistant reply (no tool use).
USER_MESSAGE_FORCE_REPLY: List[MessageCreateParam] = [
    MessageCreateParam(
        role="user",
        content=f"This is an automated test message. Reply with the message '{USER_MESSAGE_RESPONSE}'.",
        otid=USER_MESSAGE_OTID,
    )
]
# Canned prompt: force a single roll_dice tool call followed by a reply.
USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
    MessageCreateParam(
        role="user",
        content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
        otid=USER_MESSAGE_OTID,
    )
]
# Canned prompt: force exactly three parallel roll_dice tool calls.
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreateParam] = [
    MessageCreateParam(
        role="user",
        content=(
            "This is an automated test message. Please call the roll_dice tool EXACTLY three times in parallel - no more, no less. "
            "Call it with num_sides=6, num_sides=12, and num_sides=20. Make all three calls at the same time in a single response."
        ),
        otid=USER_MESSAGE_OTID,
    )
]
def assert_greeting_response(
    messages: List[Any],
    model_handle: str,
    model_settings: dict,
    streaming: bool = False,
    token_streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected sequence:
    ReasoningMessage -> AssistantMessage.

    Args:
        messages: Messages returned (or streamed) for a greeting request.
        model_handle: Model handle the request ran against.
        model_settings: Settings dict used to detect reasoner models.
        streaming: Expect a trailing stop reason + usage statistics.
        token_streaming: Skip exact content checks (chunks may be partial).
        from_db: Expect a leading persisted UserMessage.
    """
    # Filter out LettaPing messages which are keep-alive messages for SSE streams
    messages = [
        msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
    ]
    expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
        model_handle, model_settings, streaming=streaming, from_db=from_db
    )
    assert expected_message_count_min <= len(messages) <= expected_message_count_max
    # User message if loaded from db
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1
    # Reasoning message if reasoning enabled
    otid_suffix = 0
    try:
        if is_reasoner_model(model_handle, model_settings):
            assert isinstance(messages[index], ReasoningMessage)
            assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
            index += 1
            otid_suffix += 1
    except (AssertionError, IndexError):
        # Reasoning is non-deterministic, so don't throw if missing.
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        pass
    # Assistant message
    assert isinstance(messages[index], AssistantMessage)
    if not token_streaming:
        assert "teamwork" in messages[index].content.lower()
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1
    otid_suffix += 1
    # Stop reason and usage statistics if streaming
    if streaming:
        assert isinstance(messages[index], LettaStopReason)
        assert messages[index].stop_reason == "end_turn"
        index += 1
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0
def assert_tool_call_response(
    messages: List[Any],
    model_handle: str,
    model_settings: dict,
    streaming: bool = False,
    from_db: bool = False,
    with_cancellation: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected sequence:
    ReasoningMessage -> ToolCallMessage -> ToolReturnMessage ->
    ReasoningMessage -> AssistantMessage.

    Args:
        messages: Messages returned (or streamed) for a tool-call request.
        model_handle: Model handle the request ran against.
        model_settings: Settings dict used to detect reasoner models.
        streaming: Expect a trailing stop reason + usage statistics.
        from_db: Expect a leading persisted UserMessage.
        with_cancellation: Relax checks for runs cancelled mid-flight.
    """
    # Filter out LettaPing messages which are keep-alive messages for SSE streams
    messages = [
        msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
    ]
    # If cancellation happened and no messages were persisted (early cancellation), return early
    if with_cancellation and len(messages) == 0:
        return
    if not with_cancellation:
        expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
            model_handle, model_settings, tool_call=True, streaming=streaming, from_db=from_db
        )
        assert expected_message_count_min <= len(messages) <= expected_message_count_max
    # User message if loaded from db
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1
    # If cancellation happened after user message but before any response, return early
    if with_cancellation and index >= len(messages):
        return
    # Reasoning message if reasoning enabled
    otid_suffix = 0
    try:
        if is_reasoner_model(model_handle, model_settings):
            assert isinstance(messages[index], ReasoningMessage)
            assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
            index += 1
            otid_suffix += 1
    except (AssertionError, IndexError):
        # Reasoning is non-deterministic, so don't throw if missing.
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        pass
    # Special case for claude-sonnet-4-5-20250929 and opus-4.1 which can generate an extra AssistantMessage before tool call
    if (
        ("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle)
        and index < len(messages)
        and isinstance(messages[index], AssistantMessage)
    ):
        # Skip the extra AssistantMessage and move to the next message
        index += 1
        otid_suffix += 1
    # Tool call message (may be skipped if cancelled early)
    if with_cancellation and index < len(messages) and isinstance(messages[index], AssistantMessage):
        # If cancelled early, model might respond with text instead of making tool call
        assert "roll" in messages[index].content.lower() or "die" in messages[index].content.lower()
        return  # Skip tool call assertions for early cancellation
    # If cancellation happens before tool call, we might get LettaStopReason directly
    if with_cancellation and index < len(messages) and isinstance(messages[index], LettaStopReason):
        assert messages[index].stop_reason == "cancelled"
        return  # Skip remaining assertions for very early cancellation
    assert isinstance(messages[index], ToolCallMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1
    # Tool return message (otid suffix restarts per message group)
    otid_suffix = 0
    assert isinstance(messages[index], ToolReturnMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1
    # Messages from second agent step if request has not been cancelled
    if not with_cancellation:
        # Reasoning message if reasoning enabled
        otid_suffix = 0
        try:
            if is_reasoner_model(model_handle, model_settings):
                assert isinstance(messages[index], ReasoningMessage)
                assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
                index += 1
                otid_suffix += 1
        except (AssertionError, IndexError):
            # Reasoning is non-deterministic, so don't throw if missing.
            # FIX: was a bare `except:` (see above).
            pass
        # Assistant message
        assert isinstance(messages[index], AssistantMessage)
        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
        index += 1
    # Stop reason and usage statistics if streaming
    if streaming:
        assert isinstance(messages[index], LettaStopReason)
        assert messages[index].stop_reason == ("cancelled" if with_cancellation else "end_turn")
        index += 1
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0
async def accumulate_chunks(chunks, verify_token_streaming: bool = False) -> List[Any]:
    """
    Accumulates chunks into a list of messages.
    Handles both async iterators and raw SSE strings.

    Consecutive chunks of the same message_type are merged by concatenating
    their `content`; a change in message_type closes the current message.
    """
    # NOTE(review): verify_token_streaming is accepted but never used here.
    messages = []
    current_message = None
    prev_message_type = None
    # Handle raw SSE string from runs.messages.stream()
    if isinstance(chunks, str):
        import json

        for line in chunks.strip().split("\n"):
            # Only payload lines count; the terminal "[DONE]" sentinel is skipped.
            if line.startswith("data: ") and line != "data: [DONE]":
                try:
                    data = json.loads(line[6:])  # Remove 'data: ' prefix
                    if "message_type" in data:
                        # Create proper message type objects
                        message_type = data.get("message_type")
                        if message_type == "assistant_message":
                            from letta_client.types.agents import AssistantMessage

                            chunk = AssistantMessage(**data)
                        elif message_type == "reasoning_message":
                            from letta_client.types.agents import ReasoningMessage

                            chunk = ReasoningMessage(**data)
                        elif message_type == "tool_call_message":
                            from letta_client.types.agents import ToolCallMessage

                            chunk = ToolCallMessage(**data)
                        elif message_type == "tool_return_message":
                            from letta_client.types import ToolReturnMessage

                            chunk = ToolReturnMessage(**data)
                        elif message_type == "user_message":
                            from letta_client.types.agents import UserMessage

                            chunk = UserMessage(**data)
                        elif message_type == "stop_reason":
                            from letta_client.types.agents.letta_streaming_response import LettaStopReason

                            chunk = LettaStopReason(**data)
                        elif message_type == "usage_statistics":
                            from letta_client.types.agents.letta_streaming_response import LettaUsageStatistics

                            chunk = LettaUsageStatistics(**data)
                        else:
                            chunk = type("Chunk", (), data)()  # Fallback for unknown types
                        current_message_type = chunk.message_type
                        if prev_message_type != current_message_type:
                            # Type changed: flush the message built so far.
                            if current_message is not None:
                                messages.append(current_message)
                            current_message = chunk
                        else:
                            # Accumulate content for same message type
                            if hasattr(current_message, "content") and hasattr(chunk, "content"):
                                current_message.content += chunk.content
                        prev_message_type = current_message_type
                except json.JSONDecodeError:
                    # Malformed SSE payloads are skipped rather than failing the test.
                    continue
        # Flush the trailing in-progress message, if any.
        if current_message is not None:
            messages.append(current_message)
    else:
        # Handle async iterator from agents.messages.stream()
        async for chunk in chunks:
            current_message_type = chunk.message_type
            if prev_message_type != current_message_type:
                if current_message is not None:
                    messages.append(current_message)
                current_message = chunk
            else:
                # Accumulate content for same message type
                if hasattr(current_message, "content") and hasattr(chunk, "content"):
                    current_message.content += chunk.content
            prev_message_type = current_message_type
        # Flush the trailing in-progress message, if any.
        if current_message is not None:
            messages.append(current_message)
    return messages
async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5):
    """Wait *delay* seconds, then cancel the agent's in-flight run (used to
    exercise mid-run cancellation paths)."""
    await asyncio.sleep(delay)
    await client.agents.messages.cancel(agent_id=agent_id)
async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
    """
    Poll a run until it reaches a terminal state.

    Args:
        client: Async SDK client used to fetch the run.
        run_id: Identifier of the run to wait on.
        timeout: Maximum seconds to wait before raising TimeoutError.
        interval: Seconds between polls.

    Returns:
        The run once its status is "completed" or "cancelled".

    Raises:
        RuntimeError: If the run reports status "failed".
        TimeoutError: If the run is not terminal within `timeout` seconds.
    """
    start = time.time()
    while True:
        run = await client.runs.retrieve(run_id)
        if run.status == "completed":
            return run
        if run.status == "cancelled":
            # Give the server a moment to finish cancellation bookkeeping.
            # FIX: use asyncio.sleep — time.sleep() blocked the event loop,
            # stalling every other coroutine in the test for 5 seconds.
            await asyncio.sleep(5)
            return run
        if run.status == "failed":
            raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
        if time.time() - start > timeout:
            raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
        # FIX: likewise poll asynchronously instead of blocking with time.sleep().
        await asyncio.sleep(interval)
def get_expected_message_count_range(
    model_handle: str, model_settings: dict, tool_call: bool = False, streaming: bool = False, from_db: bool = False
) -> Tuple[int, int]:
    """
    Returns the expected range of number of messages for a given LLM
    configuration. A range (min, max) is used because reasoner models emit
    reasoning messages non-deterministically, and some Anthropic models add
    extra assistant messages around tool calls.

    Typical shapes: non-reasoners produce just AssistantMessage (plus
    ToolCall/ToolReturn on tool turns); reasoners may prefix each step with a
    ReasoningMessage; sonnet-4.5/opus-4.1 may add extra messages on tool turns.
    """
    # Guaranteed messages: the final assistant message is always present.
    guaranteed = 1
    # Optional messages (reasoning / extra assistant turns) widen the range.
    optional = 0
    if is_reasoner_model(model_handle, model_settings):
        optional += 1  # leading reasoning message
        if tool_call:
            anthropic_thinking = (
                model_settings.get("provider_type") == "anthropic" and model_settings.get("thinking", {}).get("type") == "enabled"
            )
            sonnet45_or_opus41 = anthropic_thinking and (
                "claude-sonnet-4-5" in model_handle or "claude-opus-4-1" in model_handle
            )
            if sonnet45_or_opus41 or not anthropic_thinking:
                # sonnet 4.5 / opus 4.1 and the non-Anthropic native reasoners
                # may emit a second reasoning message before the final reply.
                optional += 1
            if "claude-opus-4-1" in model_handle:
                # opus 4.1 can generate an extra AssistantMessage before the tool call
                optional += 1
    if tool_call:
        guaranteed += 2  # tool call + tool return messages
    if from_db:
        guaranteed += 1  # persisted user message
    if streaming:
        guaranteed += 2  # stop reason + usage statistics
    return guaranteed, guaranteed + optional
def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
    """Check if the model is a reasoning model based on its handle and settings.

    Args:
        model_handle: Provider-qualified model identifier (e.g. "openai/gpt-5").
        model_settings: Model settings dict; provider-specific keys
            ("reasoning", "thinking", "thinking_config") control reasoning mode.

    Returns:
        True when the handle/settings combination indicates a reasoning model.
    """
    provider = model_settings.get("provider_type")

    if provider == "openai":
        # OpenAI models only count as reasoners with high reasoning effort.
        reasoning_families = ("gpt-5", "o1", "o3", "o4-mini", "gpt-4.1")
        handle_matches = any(family in model_handle for family in reasoning_families)
        effort = model_settings.get("reasoning", {}).get("reasoning_effort")
        return handle_matches and effort == "high"

    if provider == "anthropic":
        # Anthropic models reason when extended thinking is explicitly enabled.
        return model_settings.get("thinking", {}).get("type") == "enabled"

    if provider in ("google_vertex", "google_ai"):
        # Gemini models reason when thoughts are included via the thinking config.
        return model_settings.get("thinking_config", {}).get("include_thoughts") is True

    return False
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
    """Return the base URL of a running Letta server.

    Honors LETTA_SERVER_URL when set; otherwise boots the server in a
    daemon thread and polls the health endpoint until it responds.
    """

    def _run_server() -> None:
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        server_thread = threading.Thread(target=_run_server, daemon=True)
        server_thread.start()

        # Wait (up to 30s) for the health endpoint to answer before handing out the URL.
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        ready = False
        while time.time() < deadline:
            try:
                if requests.get(url + "/v1/health").status_code < 500:
                    ready = True
                    break
            except requests.exceptions.RequestException:
                pass
            time.sleep(0.1)
        if not ready:
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
    return url
@pytest.fixture(scope="function")
async def client(server_url: str) -> AsyncLetta:
    """Yield an asynchronous Letta REST client bound to the test server."""
    yield AsyncLetta(base_url=server_url)
@pytest.fixture(scope="function")
async def agent_state(client: AsyncLetta) -> AgentState:
    """Yield a fresh letta_v1_agent named 'test_agent' equipped only with roll_dice.

    Base tools are excluded so message assertions see only dice-tool activity.
    The agent is deleted after the test completes.
    """
    dice_tool = await client.tools.upsert_from_function(func=roll_dice)
    agent = await client.agents.create(
        agent_type="letta_v1_agent",
        name="test_agent",
        include_base_tools=False,
        tool_ids=[dice_tool.id],
        model="openai/gpt-4o",
        embedding="openai/text-embedding-3-small",
        tags=["test"],
    )
    yield agent
    await client.agents.delete(agent.id)
# ------------------------------
# Test Cases
# ------------------------------
@pytest.mark.parametrize(
    "model_config",
    TESTED_MODEL_CONFIGS,
    ids=[handle for handle, _ in TESTED_MODEL_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_greeting(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    model_config: Tuple[str, dict],
    send_type: str,
) -> None:
    """Send a greeting through every send mode and validate the response shape.

    For each tested model, sends USER_MESSAGE_FORCE_REPLY via blocking create,
    async run, or one of the streaming modes, then checks the returned messages
    (live and re-read from the DB) with assert_greeting_response and verifies
    the run completed.
    """
    model_handle, model_settings = model_config
    # Remember the most recent message so the DB re-read below only covers this test's messages.
    last_message_page = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
    last_message = last_message_page.items[0] if last_message_page.items else None
    # Point the shared fixture agent at the parametrized model under test.
    agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings)
    if send_type == "step":
        # Blocking request/response: messages come back directly.
        response = await client.agents.messages.create(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
        )
        messages = response.messages
        run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
    elif send_type == "async":
        # Background run: poll for completion, then fetch messages from the run.
        run = await client.agents.messages.create_async(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
        )
        run = await wait_for_run_completion(client, run.id, timeout=60.0)
        messages_page = await client.runs.messages.list(run_id=run.id)
        messages = [m for m in messages_page.items if m.message_type != "user_message"]
        run_id = run.id
    else:
        # Streaming modes: accumulate chunks into complete messages.
        response = await client.agents.messages.stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
            stream_tokens=(send_type == "stream_tokens"),
            background=(send_type == "stream_tokens_background"),
        )
        messages = await accumulate_chunks(response)
        run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
    # If run_id is not in messages (e.g., due to early cancellation), get the most recent run
    if run_id is None:
        runs = await client.runs.list(agent_ids=[agent_state.id])
        run_id = runs.items[0].id if runs.items else None
    assert_greeting_response(
        messages, model_handle, model_settings, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens")
    )
    if "background" in send_type:
        # Background streams can be replayed from the start; the replay must match too.
        response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
        messages = await accumulate_chunks(response)
        assert_greeting_response(
            messages, model_handle, model_settings, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens")
        )
    # Re-read this test's messages from the DB (everything after the pre-test marker).
    messages_from_db_page = await client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None)
    messages_from_db = messages_from_db_page.items
    assert_greeting_response(messages_from_db, model_handle, model_settings, from_db=True)
    assert run_id is not None
    run = await client.runs.retrieve(run_id=run_id)
    assert run.status == "completed"
@pytest.mark.parametrize(
    "model_config",
    TESTED_MODEL_CONFIGS,
    ids=[handle for handle, _ in TESTED_MODEL_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_parallel_tool_calls(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    model_config: Tuple[str, dict],
    send_type: str,
) -> None:
    """Verify that parallel_tool_calls=True yields one ToolCallMessage with 3 calls.

    Sends USER_MESSAGE_PARALLEL_TOOL_CALL via the parametrized send mode, then
    inspects the persisted messages: exactly 3 roll_dice calls and 3 returns,
    all grouped into a single ToolCallMessage (true parallel calling), with
    each dice result within the range implied by its num_sides argument.
    """
    # Hoisted: json is used when parsing tool-call arguments below.
    import json

    model_handle, model_settings = model_config
    provider_type = model_settings.get("provider_type", "")
    if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
        pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.")
    if "gpt-5" in model_handle or "o3" in model_handle:
        pytest.skip("GPT-5 takes too long to test, o3 is bad at this task.")
    # Skip Gemini models due to issues with parallel tool calling
    if provider_type in ["google_ai", "google_vertex"]:
        pytest.skip("Gemini models are flaky for this test so we disable them for now")
    # Update model_settings to enable parallel tool calling
    modified_model_settings = model_settings.copy()
    modified_model_settings["parallel_tool_calls"] = True
    agent_state = await client.agents.update(
        agent_id=agent_state.id,
        model=model_handle,
        model_settings=modified_model_settings,
    )
    if send_type == "step":
        await client.agents.messages.create(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
        )
    elif send_type == "async":
        run = await client.agents.messages.create_async(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
        )
        await wait_for_run_completion(client, run.id, timeout=60.0)
    else:
        response = await client.agents.messages.stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
            stream_tokens=(send_type == "stream_tokens"),
            background=(send_type == "stream_tokens_background"),
        )
        await accumulate_chunks(response)
    # validate parallel tool call behavior in preserved messages
    preserved_messages_page = await client.agents.messages.list(agent_id=agent_state.id)
    preserved_messages = preserved_messages_page.items
    # collect all ToolCallMessage and ToolReturnMessage instances
    tool_call_messages = []
    tool_return_messages = []
    for msg in preserved_messages:
        if isinstance(msg, ToolCallMessage):
            tool_call_messages.append(msg)
        elif isinstance(msg, ToolReturnMessage):
            tool_return_messages.append(msg)
    # Check if tool calls are grouped in a single message (parallel) or separate messages (sequential)
    total_tool_calls = 0
    for tcm in tool_call_messages:
        if hasattr(tcm, "tool_calls") and tcm.tool_calls:
            num_calls = len(tcm.tool_calls) if isinstance(tcm.tool_calls, list) else 1
            total_tool_calls += num_calls
        elif hasattr(tcm, "tool_call"):
            total_tool_calls += 1
    # Check tool returns structure
    total_tool_returns = 0
    for trm in tool_return_messages:
        if hasattr(trm, "tool_returns") and trm.tool_returns:
            num_returns = len(trm.tool_returns) if isinstance(trm.tool_returns, list) else 1
            total_tool_returns += num_returns
        elif hasattr(trm, "tool_return"):
            total_tool_returns += 1
    # CRITICAL: For TRUE parallel tool calling with letta_v1_agent, there should be exactly ONE ToolCallMessage
    # containing multiple tool calls, not multiple ToolCallMessages
    # Verify we have exactly 3 tool calls total
    assert total_tool_calls == 3, f"Expected exactly 3 tool calls total, got {total_tool_calls}"
    assert total_tool_returns == 3, f"Expected exactly 3 tool returns total, got {total_tool_returns}"
    # Check if we have true parallel tool calling
    is_parallel = False
    if len(tool_call_messages) == 1:
        # Check if the single message contains multiple tool calls
        tcm = tool_call_messages[0]
        if hasattr(tcm, "tool_calls") and isinstance(tcm.tool_calls, list) and len(tcm.tool_calls) == 3:
            is_parallel = True
    # IMPORTANT: Assert that parallel tool calling is actually working
    # This test should FAIL if parallel tool calling is not working properly
    assert is_parallel, (
        f"Parallel tool calling is NOT working for {provider_type}! "
        f"Got {len(tool_call_messages)} ToolCallMessage(s) instead of 1 with 3 parallel calls. "
        f"When using letta_v1_agent with parallel_tool_calls=True, all tool calls should be in a single message."
    )
    # Collect all tool calls and their details for validation
    all_tool_calls = []
    tool_call_ids = set()
    num_sides_by_id = {}
    for tcm in tool_call_messages:
        if hasattr(tcm, "tool_calls") and tcm.tool_calls and isinstance(tcm.tool_calls, list):
            # Message has multiple tool calls
            for tc in tcm.tool_calls:
                all_tool_calls.append(tc)
                tool_call_ids.add(tc.tool_call_id)
                # Parse arguments to record the requested die size per call
                args = json.loads(tc.arguments)
                num_sides_by_id[tc.tool_call_id] = int(args["num_sides"])
        elif hasattr(tcm, "tool_call") and tcm.tool_call:
            # Message has single tool call
            tc = tcm.tool_call
            all_tool_calls.append(tc)
            tool_call_ids.add(tc.tool_call_id)
            # Parse arguments to record the requested die size per call
            args = json.loads(tc.arguments)
            num_sides_by_id[tc.tool_call_id] = int(args["num_sides"])
    # Verify each tool call
    for tc in all_tool_calls:
        assert tc.name == "roll_dice", f"Expected tool call name 'roll_dice', got '{tc.name}'"
        # Support Anthropic (toolu_), OpenAI (call_), and Gemini (UUID) tool call ID formats
        # Gemini uses UUID format which could start with any alphanumeric character
        valid_id_format = (
            tc.tool_call_id.startswith("toolu_")
            or tc.tool_call_id.startswith("call_")
            or (len(tc.tool_call_id) > 0 and tc.tool_call_id[0].isalnum())  # UUID format for Gemini
        )
        assert valid_id_format, f"Unexpected tool call ID format: {tc.tool_call_id}"
    # Collect all tool returns for validation
    all_tool_returns = []
    for trm in tool_return_messages:
        if hasattr(trm, "tool_returns") and trm.tool_returns and isinstance(trm.tool_returns, list):
            # Message has multiple tool returns
            all_tool_returns.extend(trm.tool_returns)
        elif hasattr(trm, "tool_return") and trm.tool_return:
            # Message has single tool return (create a mock object if needed)
            # Since ToolReturnMessage might not have individual tool_return, check the structure
            pass
    # If all_tool_returns is empty, it means returns are structured differently
    # Let's check the actual structure
    if not all_tool_returns:
        print("Note: Tool returns may be structured differently than expected")
        # For now, just verify we got the right number of messages
        assert len(tool_return_messages) > 0, "No tool return messages found"
    # Verify tool returns if we have them in the expected format
    for tr in all_tool_returns:
        assert tr.type == "tool", f"Tool return type should be 'tool', got '{tr.type}'"
        assert tr.status == "success", f"Tool return status should be 'success', got '{tr.status}'"
        assert tr.tool_call_id in tool_call_ids, f"Tool return ID '{tr.tool_call_id}' not found in tool call IDs: {tool_call_ids}"
        # Verify the dice roll result is within the valid range
        dice_result = int(tr.tool_return)
        expected_max = num_sides_by_id[tr.tool_call_id]
        assert 1 <= dice_result <= expected_max, (
            f"Dice roll result {dice_result} is not within valid range 1-{expected_max} for tool call {tr.tool_call_id}"
        )
@pytest.mark.parametrize(
    "model_config",
    TESTED_MODEL_CONFIGS,
    ids=[handle for handle, _ in TESTED_MODEL_CONFIGS],
)
@pytest.mark.parametrize(
    ["send_type", "cancellation"],
    list(
        itertools.product(
            ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"]
        )
    ),
    ids=[
        f"{s}-{c}"
        for s, c in itertools.product(
            ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"]
        )
    ],
)
@pytest.mark.asyncio(loop_scope="function")
async def test_tool_call(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    model_config: Tuple[str, dict],
    send_type: str,
    cancellation: str,
) -> None:
    """Exercise a dice-roll tool call in every send mode, with and without cancellation.

    Sends USER_MESSAGE_ROLL_DICE via the parametrized send mode, optionally
    scheduling a concurrent run cancellation, then validates the live messages,
    the background-stream replay (when applicable), and the DB-persisted
    messages with assert_tool_call_response, and checks the final run status.
    """
    model_handle, model_settings = model_config
    # Skip models with OTID mismatch issues between ToolCallMessage and ToolReturnMessage
    if "gpt-5" in model_handle or "claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle:
        pytest.skip(f"Skipping {model_handle} due to OTID chain issue - messages receive incorrect OTID suffixes")
    # Remember the most recent message so the DB re-read below only covers this test's messages.
    last_message_page = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
    last_message = last_message_page.items[0] if last_message_page.items else None
    # Point the shared fixture agent at the parametrized model under test.
    agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings)
    if cancellation == "with_cancellation":
        delay = 5 if "gpt-5" in model_handle else 0.5  # increase delay for responses api
        # NOTE(review): fire-and-forget — the task is never awaited or cancelled,
        # so any exception it raises is silently dropped; confirm this is intended.
        _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))
    if send_type == "step":
        # Blocking request/response: messages come back directly.
        response = await client.agents.messages.create(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
        )
        messages = response.messages
        run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
    elif send_type == "async":
        # Background run: poll for completion, then fetch messages from the run.
        run = await client.agents.messages.create_async(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
        )
        run = await wait_for_run_completion(client, run.id, timeout=60.0)
        messages_page = await client.runs.messages.list(run_id=run.id)
        messages = [m for m in messages_page.items if m.message_type != "user_message"]
        run_id = run.id
    else:
        # Streaming modes: accumulate chunks into complete messages.
        response = await client.agents.messages.stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
            stream_tokens=(send_type == "stream_tokens"),
            background=(send_type == "stream_tokens_background"),
        )
        messages = await accumulate_chunks(response)
        run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
    # If run_id is not in messages (e.g., due to early cancellation), get the most recent run
    if run_id is None:
        runs = await client.runs.list(agent_ids=[agent_state.id])
        run_id = runs.items[0].id if runs.items else None
    assert_tool_call_response(
        messages, model_handle, model_settings, streaming=("stream" in send_type), with_cancellation=(cancellation == "with_cancellation")
    )
    if "background" in send_type:
        # Background streams can be replayed from the start; the replay must match too.
        response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
        messages = await accumulate_chunks(response)
        assert_tool_call_response(
            messages,
            model_handle,
            model_settings,
            streaming=("stream" in send_type),
            with_cancellation=(cancellation == "with_cancellation"),
        )
    # Re-read this test's messages from the DB (everything after the pre-test marker).
    messages_from_db_page = await client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None)
    messages_from_db = messages_from_db_page.items
    assert_tool_call_response(
        messages_from_db, model_handle, model_settings, from_db=True, with_cancellation=(cancellation == "with_cancellation")
    )
    assert run_id is not None
    run = await client.runs.retrieve(run_id=run_id)
    assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed")

File diff suppressed because it is too large Load Diff

View File

@@ -1,67 +0,0 @@
from conftest import create_test_module
# Sample code for tools
FRIENDLY_FUNC_SOURCE_CODE = '''
def friendly_func():
"""
Returns a friendly message.
Returns:
str: A friendly message.
"""
return "HI HI HI HI HI!"
'''
UNFRIENDLY_FUNC_SOURCE_CODE = '''
def unfriendly_func():
"""
Returns an unfriendly message.
Returns:
str: An unfriendly message.
"""
return "NO NO NO NO NO!"
'''
UNFRIENDLY_FUNC_SOURCE_CODE_V2 = '''
def unfriendly_func():
"""
Returns an unfriendly message.
Returns:
str: An unfriendly message.
"""
return "BYE BYE BYE BYE BYE!"
'''
# Define test parameters for tools
TOOLS_CREATE_PARAMS = [
("friendly_func", {"source_code": FRIENDLY_FUNC_SOURCE_CODE}, {"name": "friendly_func"}, None),
("unfriendly_func", {"source_code": UNFRIENDLY_FUNC_SOURCE_CODE}, {"name": "unfriendly_func"}, None),
]
TOOLS_UPSERT_PARAMS = [
("unfriendly_func", {"source_code": UNFRIENDLY_FUNC_SOURCE_CODE_V2}, {}, None),
]
TOOLS_UPDATE_PARAMS = [
("friendly_func", {"tags": ["sdk_test"]}, {}, None),
("unfriendly_func", {"return_char_limit": 300}, {}, None),
]
TOOLS_LIST_PARAMS = [
({}, 2),
({"name": "friendly_func"}, 1),
]
# Create all test module components at once
globals().update(
create_test_module(
resource_name="tools",
id_param_name="tool_id",
create_params=TOOLS_CREATE_PARAMS,
upsert_params=TOOLS_UPSERT_PARAMS,
update_params=TOOLS_UPDATE_PARAMS,
list_params=TOOLS_LIST_PARAMS,
)
)

View File

@@ -4,8 +4,9 @@ import uuid
import pytest
from dotenv import load_dotenv
from letta_client import AgentState, Letta, MessageCreate
from letta_client.core.api_error import ApiError
from letta_client import APIError, Letta
from letta_client.types import MessageCreateParam
from letta_client.types.agent_state import AgentState
from sqlalchemy import delete
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
@@ -48,7 +49,7 @@ def client(request):
# Override the base_url if the LETTA_API_URL is set
base_url = api_url if api_url else server_url
# create the Letta client
yield Letta(base_url=base_url, token=None)
yield Letta(base_url=base_url)
# Fixture for test agent
@@ -127,31 +128,31 @@ def test_add_and_manage_tags_for_agent(client: Letta):
assert len(agent.tags) == 0
# Step 1: Add multiple tags to the agent
client.agents.modify(agent_id=agent.id, tags=tags_to_add)
client.agents.update(agent_id=agent.id, tags=tags_to_add)
# Step 2: Retrieve tags for the agent and verify they match the added tags
retrieved_tags = client.agents.retrieve(agent_id=agent.id).tags
retrieved_tags = client.agents.retrieve(agent_id=agent.id, include=["agent.tags"]).tags
assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}"
# Step 3: Retrieve agents by each tag to ensure the agent is associated correctly
for tag in tags_to_add:
agents_with_tag = client.agents.list(tags=[tag])
agents_with_tag = client.agents.list(tags=[tag]).items
assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'"
# Step 4: Delete a specific tag from the agent and verify its removal
tag_to_delete = tags_to_add.pop()
client.agents.modify(agent_id=agent.id, tags=tags_to_add)
client.agents.update(agent_id=agent.id, tags=tags_to_add)
# Verify the tag is removed from the agent's tags
remaining_tags = client.agents.retrieve(agent_id=agent.id).tags
remaining_tags = client.agents.retrieve(agent_id=agent.id, include=["agent.tags"]).tags
assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected"
assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}"
# Step 5: Delete all remaining tags from the agent
client.agents.modify(agent_id=agent.id, tags=[])
client.agents.update(agent_id=agent.id, tags=[])
# Verify all tags are removed
final_tags = client.agents.retrieve(agent_id=agent.id).tags
final_tags = client.agents.retrieve(agent_id=agent.id, include=["agent.tags"]).tags
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"
# Remove agent
@@ -258,15 +259,15 @@ def test_update_agent_memory_label(client: Letta):
agent = client.agents.create(model="letta/letta-free", embedding="letta/letta-free", memory_blocks=[{"label": "human", "value": ""}])
try:
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id).items]
example_label = current_labels[0]
example_new_label = "example_new_label"
assert example_new_label not in [b.label for b in client.agents.blocks.list(agent_id=agent.id)]
assert example_new_label not in [b.label for b in client.agents.blocks.list(agent_id=agent.id).items]
client.agents.blocks.modify(agent_id=agent.id, block_label=example_label, label=example_new_label)
client.agents.blocks.update(agent_id=agent.id, block_label=example_label, label=example_new_label)
updated_blocks = client.agents.blocks.list(agent_id=agent.id)
assert example_new_label in [b.label for b in updated_blocks]
assert example_new_label in [b.label for b in updated_blocks.items]
finally:
client.agents.delete(agent.id)
@@ -275,7 +276,7 @@ def test_update_agent_memory_label(client: Letta):
def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState):
"""Test that we can add and remove a block from an agent's memory"""
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id).items]
example_new_label = current_labels[0] + "_v2"
example_new_value = "example value"
assert example_new_label not in current_labels
@@ -290,14 +291,14 @@ def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState):
agent_id=agent.id,
block_id=block.id,
)
assert example_new_label in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)]
assert example_new_label in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id).items]
# Now unlink the block
updated_agent = client.agents.blocks.detach(
agent_id=agent.id,
block_id=block.id,
)
assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)]
assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id).items]
def test_update_agent_memory_limit(client: Letta):
@@ -312,11 +313,11 @@ def test_update_agent_memory_limit(client: Letta):
],
)
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id).items]
example_label = current_labels[0]
example_new_limit = 1
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id).items]
example_label = current_labels[0]
example_new_limit = 1
current_block = client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label)
@@ -326,8 +327,8 @@ def test_update_agent_memory_limit(client: Letta):
assert example_new_limit < current_block_length
# We expect this to throw a value error
with pytest.raises(ApiError):
client.agents.blocks.modify(
with pytest.raises(APIError):
client.agents.blocks.update(
agent_id=agent.id,
block_label=example_label,
limit=example_new_limit,
@@ -336,7 +337,7 @@ def test_update_agent_memory_limit(client: Letta):
# Now try the same thing with a higher limit
example_new_limit = current_block_length + 10000
assert example_new_limit > current_block_length
client.agents.blocks.modify(
client.agents.blocks.update(
agent_id=agent.id,
block_label=example_label,
limit=example_new_limit,
@@ -381,7 +382,7 @@ def test_function_always_error(client: Letta):
# get function response
response = client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="call the testing_method function and tell me the result")],
messages=[MessageCreateParam(role="user", content="call the testing_method function and tell me the result")],
)
print(response.messages)
@@ -420,23 +421,25 @@ def test_attach_detach_agent_tool(client: Letta, agent: AgentState):
tool = client.tools.upsert_from_function(func=example_tool)
# Initially tool should not be attached
initial_tools = client.agents.tools.list(agent_id=agent.id)
initial_tools = client.agents.tools.list(agent_id=agent.id).items
assert tool.id not in [t.id for t in initial_tools]
# Attach tool
new_agent_state = client.agents.tools.attach(agent_id=agent.id, tool_id=tool.id)
client.agents.tools.attach(agent_id=agent.id, tool_id=tool.id)
new_agent_state = client.agents.retrieve(agent_id=agent.id, include=["agent.tools"])
assert tool.id in [t.id for t in new_agent_state.tools]
# Verify tool is attached
updated_tools = client.agents.tools.list(agent_id=agent.id)
updated_tools = client.agents.tools.list(agent_id=agent.id).items
assert tool.id in [t.id for t in updated_tools]
# Detach tool
new_agent_state = client.agents.tools.detach(agent_id=agent.id, tool_id=tool.id)
client.agents.tools.detach(agent_id=agent.id, tool_id=tool.id)
new_agent_state = client.agents.retrieve(agent_id=agent.id, include=["agent.tools"])
assert tool.id not in [t.id for t in new_agent_state.tools]
# Verify tool is detached
final_tools = client.agents.tools.list(agent_id=agent.id)
final_tools = client.agents.tools.list(agent_id=agent.id).items
assert tool.id not in [t.id for t in final_tools]
finally:
@@ -449,10 +452,12 @@ def test_attach_detach_agent_tool(client: Letta, agent: AgentState):
def test_messages(client: Letta, agent: AgentState):
# _reset_config()
send_message_response = client.agents.messages.create(agent_id=agent.id, messages=[MessageCreate(role="user", content="Test message")])
send_message_response = client.agents.messages.create(
agent_id=agent.id, messages=[MessageCreateParam(role="user", content="Test message")]
)
assert send_message_response, "Sending message failed"
messages_response = client.agents.messages.list(agent_id=agent.id, limit=1)
messages_response = client.agents.messages.list(agent_id=agent.id, limit=1).items
assert len(messages_response) > 0, "Retrieving messages failed"
@@ -466,7 +471,7 @@ def test_messages(client: Letta, agent: AgentState):
# # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
# async def send_message_task(message: str):
# response = await asyncio.to_thread(
# client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)]
# client.agents.messages.create, agent_id=agent.id, messages=[MessageCreateParam(role="user", content=message)]
# )
# assert response, f"Sending message '{message}' failed"
# return response
@@ -497,23 +502,23 @@ def test_messages(client: Letta, agent: AgentState):
def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two):
"""Test listing agents with pagination and query text filtering."""
# Test query text filtering
search_results = client.agents.list(query_text="search agent")
search_results = client.agents.list(query_text="search agent").items
assert len(search_results) == 2
search_agent_ids = {agent.id for agent in search_results}
assert search_agent_one.id in search_agent_ids
assert search_agent_two.id in search_agent_ids
assert agent.id not in search_agent_ids
different_results = client.agents.list(query_text="client")
different_results = client.agents.list(query_text="client").items
assert len(different_results) == 1
assert different_results[0].id == agent.id
# Test pagination
first_page = client.agents.list(query_text="search agent", limit=1)
first_page = client.agents.list(query_text="search agent", limit=1).items
assert len(first_page) == 1
first_agent = first_page[0]
second_page = client.agents.list(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor
second_page = client.agents.list(query_text="search agent", after=first_agent.id, limit=1).items # Use agent ID as cursor
assert len(second_page) == 1
assert second_page[0].id != first_agent.id
@@ -523,7 +528,7 @@ def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two)
assert all_ids == {search_agent_one.id, search_agent_two.id}
# Test listing without any filters; make less flakey by checking we have at least 3 agents in case created elsewhere
all_agents = client.agents.list()
all_agents = client.agents.list().items
assert len(all_agents) >= 3
assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent])
@@ -569,7 +574,7 @@ def test_agent_creation(client: Letta):
assert agent.id is not None
# Verify the blocks are properly attached
agent_blocks = client.agents.blocks.list(agent_id=agent.id)
agent_blocks = client.agents.blocks.list(agent_id=agent.id).items
agent_block_ids = {block.id for block in agent_blocks}
# Check that all memory blocks are present
@@ -579,7 +584,7 @@ def test_agent_creation(client: Letta):
assert user_preferences_block.id in agent_block_ids, f"User preferences block {user_preferences_block.id} not attached to agent"
# Verify the tools are properly attached
agent_tools = client.agents.tools.list(agent_id=agent.id)
agent_tools = client.agents.tools.list(agent_id=agent.id).items
assert len(agent_tools) == 2
tool_ids = {tool1.id, tool2.id}
assert all(tool.id in tool_ids for tool in agent_tools)
@@ -597,20 +602,20 @@ def test_initial_sequence(client: Letta):
model="letta/letta-free",
embedding="letta/letta-free",
initial_message_sequence=[
MessageCreate(
MessageCreateParam(
role="assistant",
content="Hello, how are you?",
),
MessageCreate(role="user", content="I'm good, and you?"),
MessageCreateParam(role="user", content="I'm good, and you?"),
],
)
# list messages
messages = client.agents.messages.list(agent_id=agent.id)
messages = client.agents.messages.list(agent_id=agent.id).items
response = client.agents.messages.create(
agent_id=agent.id,
messages=[
MessageCreate(
MessageCreateParam(
role="user",
content="hello assistant!",
)
@@ -637,7 +642,7 @@ def test_initial_sequence(client: Letta):
# response = client.agents.messages.create(
# agent_id=agent.id,
# messages=[
# MessageCreate(
# MessageCreateParam(
# role="user",
# content="What timezone are you in?",
# )
@@ -653,7 +658,7 @@ def test_initial_sequence(client: Letta):
# )
#
# # test updating the timezone
# client.agents.modify(agent_id=agent.id, timezone="America/New_York")
# client.agents.update(agent_id=agent.id, timezone="America/New_York")
# agent = client.agents.retrieve(agent_id=agent.id)
# assert agent.timezone == "America/New_York"
@@ -678,10 +683,10 @@ def test_attach_sleeptime_block(client: Letta):
client.agents.blocks.attach(agent_id=agent.id, block_id=block.id)
# verify block is attached to both agents
blocks = client.agents.blocks.list(agent_id=agent.id)
blocks = client.agents.blocks.list(agent_id=agent.id).items
assert block.id in [b.id for b in blocks]
blocks = client.agents.blocks.list(agent_id=sleeptime_id)
blocks = client.agents.blocks.list(agent_id=sleeptime_id).items
assert block.id in [b.id for b in blocks]
# blocks = client.blocks.list(project_id="test")

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,8 @@ import time
from datetime import datetime, timezone
from typing import Dict, Iterator, List, Optional, Tuple
from letta_client import Letta, SystemMessage
from letta_client import Letta
from letta_client.types.agents.system_message import SystemMessage
from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector

15
uv.lock generated
View File

@@ -2335,7 +2335,7 @@ wheels = [
[[package]]
name = "letta"
version = "0.14.0"
version = "0.14.1"
source = { editable = "." }
dependencies = [
{ name = "aiomultiprocess" },
@@ -2522,7 +2522,7 @@ requires-dist = [
{ name = "langchain", marker = "extra == 'external-tools'", specifier = ">=0.3.7" },
{ name = "langchain-community", marker = "extra == 'desktop'", specifier = ">=0.3.7" },
{ name = "langchain-community", marker = "extra == 'external-tools'", specifier = ">=0.3.7" },
{ name = "letta-client", specifier = ">=0.1.319" },
{ name = "letta-client", specifier = ">=1.1.2" },
{ name = "llama-index", specifier = ">=0.12.2" },
{ name = "llama-index-embeddings-openai", specifier = ">=0.3.1" },
{ name = "locust", marker = "extra == 'desktop'", specifier = ">=2.31.5" },
@@ -2599,18 +2599,19 @@ provides-extras = ["postgres", "redis", "pinecone", "sqlite", "experimental", "s
[[package]]
name = "letta-client"
version = "0.1.319"
version = "1.1.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "distro" },
{ name = "httpx" },
{ name = "httpx-sse" },
{ name = "pydantic" },
{ name = "pydantic-core" },
{ name = "sniffio" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/07/48/8a70ff23e9adcf7b3a9262b03fd0576eae03bafb61b7d229e1059c16ce7c/letta_client-0.1.319.tar.gz", hash = "sha256:30a2bd63d5e27759ca57a3850f2be3d81d828e90b5a7a6c35285b4ecceaafc74", size = 197085, upload-time = "2025-09-08T23:17:40.636Z" }
sdist = { url = "https://files.pythonhosted.org/packages/28/8c/b31ad4bc3fad1c563b4467762b67f7eca7bc65cb2c0c2ca237b6b6a485ae/letta_client-1.1.2.tar.gz", hash = "sha256:2687b3aebc31401db4f273719db459b2a7a2a527779b87d56c53d2bdf664e4d3", size = 231825, upload-time = "2025-11-21T03:06:06.737Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1a/78/52d64b29ce0ffcd5cc3f1318a5423c9ead95c1388e8a95bd6156b7335ad3/letta_client-0.1.319-py3-none-any.whl", hash = "sha256:e93cda21d39de21bf2353f1aa71e82054eac209156bd4f1780efff85949a32d3", size = 493310, upload-time = "2025-09-08T23:17:39.134Z" },
{ url = "https://files.pythonhosted.org/packages/ba/22/ec950b367a3cc5e2c8ae426e84c13a2d824ea4e337dbd7b94300a1633929/letta_client-1.1.2-py3-none-any.whl", hash = "sha256:86d9c6f2f9e773965f2107898584e8650f3843f938604da9fd1dfe6a462af533", size = 357516, upload-time = "2025-11-21T03:06:05.559Z" },
]
[[package]]