chore: add ty + pre-commit hook and repeal even more ruff rules (#9504)
* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import * remove all ignores, add FastAPI rules and Ruff rules * add ty and precommit * ruff stuff * ty check fixes * ty check fixes pt 2 * error on invalid
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
|
||||
from letta import RESTClient
|
||||
from letta import RESTClient # type: ignore[attr-defined]
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.source import Source
|
||||
|
||||
@@ -142,7 +142,7 @@ def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], k
|
||||
send_message_function_call = target_message.tool_call
|
||||
try:
|
||||
arguments = json.loads(send_message_function_call.arguments)
|
||||
except:
|
||||
except Exception:
|
||||
raise InvalidToolCallError(messages=[target_message], explanation="Function call arguments could not be loaded into JSON")
|
||||
|
||||
# Message field not in send_message
|
||||
|
||||
@@ -234,9 +234,9 @@ async def test_web_search() -> None:
|
||||
|
||||
# Check for education-related information in summary and highlights
|
||||
result_text = ""
|
||||
if "summary" in result and result["summary"]:
|
||||
if result.get("summary"):
|
||||
result_text += " " + result["summary"].lower()
|
||||
if "highlights" in result and result["highlights"]:
|
||||
if result.get("highlights"):
|
||||
for highlight in result["highlights"]:
|
||||
result_text += " " + highlight.lower()
|
||||
|
||||
|
||||
@@ -318,7 +318,7 @@ def get_secret_code(input_text: str) -> str:
|
||||
print(" ✓ Without client_tools, server tool executed directly (no approval required)")
|
||||
|
||||
# The response should eventually contain the server value
|
||||
all_content = " ".join([msg.content for msg in response4.messages if hasattr(msg, "content") and msg.content])
|
||||
" ".join([msg.content for msg in response4.messages if hasattr(msg, "content") and msg.content])
|
||||
tool_returns = [msg for msg in response4.messages if msg.message_type == "tool_return_message"]
|
||||
if tool_returns:
|
||||
server_return_value = tool_returns[0].tool_return
|
||||
|
||||
@@ -362,7 +362,7 @@ def test_invoke_tool_after_turning_off_requires_approval(
|
||||
try:
|
||||
assert messages[idx].message_type == "assistant_message"
|
||||
idx += 1
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert messages[idx].message_type == "tool_call_message"
|
||||
@@ -375,7 +375,7 @@ def test_invoke_tool_after_turning_off_requires_approval(
|
||||
try:
|
||||
assert messages[idx].message_type == "assistant_message"
|
||||
idx += 1
|
||||
except:
|
||||
except Exception:
|
||||
assert messages[idx].message_type == "tool_call_message"
|
||||
idx += 1
|
||||
assert messages[idx].message_type == "tool_return_message"
|
||||
@@ -1324,7 +1324,7 @@ def test_agent_records_last_stop_reason_after_approval_flow(
|
||||
assert agent_after_approval.last_stop_reason != initial_stop_reason # Should be different from initial
|
||||
|
||||
# Send follow-up message to complete the flow
|
||||
response2 = client.agents.messages.create(
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_FOLLOW_UP,
|
||||
)
|
||||
|
||||
@@ -86,7 +86,7 @@ def remove_stale_agents(client):
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_obj(client: Letta) -> AgentState:
|
||||
"""Create a test agent that we can call functions on"""
|
||||
send_message_to_agent_tool = list(client.tools.list(name="send_message_to_agent_and_wait_for_reply"))[0]
|
||||
send_message_to_agent_tool = next(iter(client.tools.list(name="send_message_to_agent_and_wait_for_reply")))
|
||||
agent_state_instance = client.agents.create(
|
||||
agent_type="letta_v1_agent",
|
||||
include_base_tools=True,
|
||||
@@ -218,7 +218,7 @@ def test_send_message_to_agents_with_tags_simple(client: Letta):
|
||||
secret_word = "banana"
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_tags_tool_id = list(client.tools.list(name="send_message_to_agents_matching_tags"))[0].id
|
||||
send_message_to_agents_matching_tags_tool_id = next(iter(client.tools.list(name="send_message_to_agents_matching_tags"))).id
|
||||
manager_agent_state = client.agents.create(
|
||||
agent_type="letta_v1_agent",
|
||||
name="manager_agent",
|
||||
@@ -329,7 +329,7 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client: Letta, roll_d
|
||||
test_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_tags_tool_id = list(client.tools.list(name="send_message_to_agents_matching_tags"))[0].id
|
||||
send_message_to_agents_matching_tags_tool_id = next(iter(client.tools.list(name="send_message_to_agents_matching_tags"))).id
|
||||
manager_agent_state = client.agents.create(
|
||||
agent_type="letta_v1_agent",
|
||||
tool_ids=[send_message_to_agents_matching_tags_tool_id],
|
||||
|
||||
@@ -370,7 +370,7 @@ def assert_greeting_with_assistant_message_response(
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
except Exception:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
@@ -508,7 +508,7 @@ def assert_greeting_without_assistant_message_response(
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
except Exception:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
@@ -664,7 +664,7 @@ def assert_tool_call_response(
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
except Exception:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
@@ -700,7 +700,7 @@ def assert_tool_call_response(
|
||||
assert isinstance(messages[index], (ReasoningMessage, HiddenReasoningMessage))
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
except:
|
||||
except Exception:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
@@ -856,7 +856,7 @@ def assert_image_input_response(
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
except Exception:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
@@ -1889,7 +1889,7 @@ def test_async_greeting_with_assistant_message(
|
||||
|
||||
messages_page = client.runs.messages.list(run_id=run.id)
|
||||
messages = messages_page.items
|
||||
usage = client.runs.usage.retrieve(run_id=run.id)
|
||||
client.runs.usage.retrieve(run_id=run.id)
|
||||
|
||||
# TODO: add results API test later
|
||||
assert_greeting_with_assistant_message_response(messages, model_handle, model_settings, from_db=True) # TODO: remove from_db=True later
|
||||
@@ -2267,7 +2267,7 @@ def test_job_creation_for_send_message(
|
||||
assert len(new_runs) == 1
|
||||
|
||||
for run in runs:
|
||||
if run.id == list(new_runs)[0]:
|
||||
if run.id == next(iter(new_runs)):
|
||||
assert run.status == "completed"
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ from letta_client.types.agents.letta_streaming_response import LettaPing, LettaS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Helper Functions and Constants
|
||||
@@ -132,7 +134,7 @@ def assert_greeting_response(
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
except Exception:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
@@ -203,7 +205,7 @@ def assert_tool_call_response(
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
except Exception:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
@@ -256,7 +258,7 @@ def assert_tool_call_response(
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
except Exception:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
@@ -890,8 +892,10 @@ async def test_tool_call(
|
||||
agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings)
|
||||
|
||||
if cancellation == "with_cancellation":
|
||||
delay = 5 if "gpt-5" in model_handle else 0.5 # increase delay for responses api
|
||||
delay = 5 if "gpt-5" in model_handle else 0.5
|
||||
_cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))
|
||||
_background_tasks.add(_cancellation_task)
|
||||
_cancellation_task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
if send_type == "step":
|
||||
response = await client.agents.messages.create(
|
||||
|
||||
@@ -218,7 +218,7 @@ async def test_summarize_empty_message_buffer(server: SyncServer, actor, llm_con
|
||||
|
||||
# Run summarization - this may fail with empty buffer, which is acceptable behavior
|
||||
try:
|
||||
summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
_summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
# If it succeeds, verify result
|
||||
assert isinstance(result, list)
|
||||
|
||||
@@ -311,7 +311,7 @@ async def test_summarize_initialization_messages_only(server: SyncServer, actor,
|
||||
|
||||
# Run summarization - force=True with system messages only may fail
|
||||
try:
|
||||
summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
_summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -367,7 +367,7 @@ async def test_summarize_small_conversation(server: SyncServer, actor, llm_confi
|
||||
# Run summarization with force=True
|
||||
# Note: force=True with clear=True can be very aggressive and may fail on small message sets
|
||||
try:
|
||||
summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
_summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -460,7 +460,7 @@ async def test_summarize_large_tool_calls(server: SyncServer, actor, llm_config:
|
||||
assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars"
|
||||
|
||||
# Run summarization
|
||||
summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
_summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -564,7 +564,7 @@ async def test_summarize_multiple_large_tool_calls(server: SyncServer, actor, ll
|
||||
assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars"
|
||||
|
||||
# Run summarization
|
||||
summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
_summary, result, _ = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -724,7 +724,7 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
|
||||
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
summary, result, summary_text = await agent_loop.compact(messages=in_context_messages)
|
||||
_summary, result, summary_text = await agent_loop.compact(messages=in_context_messages)
|
||||
|
||||
assert isinstance(result, list)
|
||||
|
||||
@@ -810,7 +810,7 @@ async def test_compact_returns_valid_summary_message_and_event_message(server: S
|
||||
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
summary_message_obj, compacted_messages, summary_text = await agent_loop.compact(messages=in_context_messages)
|
||||
summary_message_obj, _compacted_messages, summary_text = await agent_loop.compact(messages=in_context_messages)
|
||||
|
||||
# Verify we can construct a valid SummaryMessage from compact() return values
|
||||
summary_msg = SummaryMessage(
|
||||
@@ -971,7 +971,7 @@ async def test_v3_compact_uses_compaction_settings_model_and_model_settings(serv
|
||||
# Patch simple_summary so we don't hit the real LLM and can inspect llm_config
|
||||
with patch.object(summarizer_all, "simple_summary", new=fake_simple_summary):
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
summary_msg, compacted, _ = await agent_loop.compact(messages=in_context_messages)
|
||||
summary_msg, _compacted, _ = await agent_loop.compact(messages=in_context_messages)
|
||||
|
||||
assert summary_msg is not None
|
||||
assert "value" in captured_llm_config
|
||||
@@ -1059,7 +1059,7 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
|
||||
|
||||
caplog.set_level("ERROR")
|
||||
|
||||
summary, result, summary_text = await agent_loop.compact(
|
||||
_summary, result, summary_text = await agent_loop.compact(
|
||||
messages=in_context_messages,
|
||||
trigger_threshold=context_limit,
|
||||
)
|
||||
@@ -2015,7 +2015,7 @@ async def test_compact_with_stats_params_embeds_stats(server: SyncServer, actor,
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
# Call compact with stats params
|
||||
summary_message_obj, compacted_messages, summary_text = await agent_loop.compact(
|
||||
summary_message_obj, compacted_messages, _summary_text = await agent_loop.compact(
|
||||
messages=in_context_messages,
|
||||
use_summary_role=True,
|
||||
trigger="post_step_context_check",
|
||||
|
||||
@@ -45,7 +45,7 @@ async def sarah_agent(server, default_user):
|
||||
# Cleanup
|
||||
try:
|
||||
await server.agent_manager.delete_agent_async(agent.id, default_user)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ async def wait_for_embedding(
|
||||
if any(msg["id"] == message_id for msg, _, _ in results):
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Log but don't fail - Turbopuffer might still be processing
|
||||
pass
|
||||
|
||||
@@ -347,7 +347,7 @@ async def test_turbopuffer_metadata_attributes(default_user, enable_turbopuffer)
|
||||
# Clean up on error
|
||||
try:
|
||||
await client.delete_all_passages(archive_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
raise e
|
||||
|
||||
@@ -409,7 +409,7 @@ async def test_hybrid_search_with_real_tpuf(default_user, enable_turbopuffer):
|
||||
]
|
||||
|
||||
# Create simple embeddings for testing (normally you'd use a real embedding model)
|
||||
embeddings = [[float(i), float(i + 5), float(i + 10)] for i in range(len(texts))]
|
||||
[[float(i), float(i + 5), float(i + 10)] for i in range(len(texts))]
|
||||
passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts]
|
||||
|
||||
# Insert passages
|
||||
@@ -487,7 +487,7 @@ async def test_hybrid_search_with_real_tpuf(default_user, enable_turbopuffer):
|
||||
# Clean up
|
||||
try:
|
||||
await client.delete_all_passages(archive_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -522,7 +522,7 @@ async def test_tag_filtering_with_real_tpuf(default_user, enable_turbopuffer):
|
||||
["javascript", "react"],
|
||||
]
|
||||
|
||||
embeddings = [[float(i), float(i + 5), float(i + 10)] for i in range(len(texts))]
|
||||
[[float(i), float(i + 5), float(i + 10)] for i in range(len(texts))]
|
||||
passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts]
|
||||
|
||||
# Insert passages with tags
|
||||
@@ -615,7 +615,7 @@ async def test_tag_filtering_with_real_tpuf(default_user, enable_turbopuffer):
|
||||
# Clean up
|
||||
try:
|
||||
await client.delete_all_passages(archive_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -754,7 +754,7 @@ async def test_temporal_filtering_with_real_tpuf(default_user, enable_turbopuffe
|
||||
# Clean up
|
||||
try:
|
||||
await client.delete_all_passages(archive_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -865,12 +865,6 @@ def test_message_text_extraction(server, default_user):
|
||||
agent_id="test-agent",
|
||||
)
|
||||
text6 = manager._extract_message_text(msg6)
|
||||
expected_parts = [
|
||||
"User said:",
|
||||
'Tool call: search({\n "query": "test"\n})',
|
||||
"Tool result: Found 5 results",
|
||||
"I should help the user",
|
||||
]
|
||||
assert (
|
||||
text6
|
||||
== '{"content": "User said: Tool call: search({\\n \\"query\\": \\"test\\"\\n}) Tool result: Found 5 results I should help the user"}'
|
||||
@@ -1112,7 +1106,7 @@ async def test_message_dual_write_with_real_tpuf(enable_message_embedding, defau
|
||||
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
|
||||
|
||||
# Generate embeddings (dummy for test)
|
||||
embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]
|
||||
[[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]
|
||||
|
||||
# Insert messages into Turbopuffer
|
||||
success = await client.insert_messages(
|
||||
@@ -1144,7 +1138,7 @@ async def test_message_dual_write_with_real_tpuf(enable_message_embedding, defau
|
||||
# Clean up namespace
|
||||
try:
|
||||
await client.delete_all_messages(agent_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1205,7 +1199,7 @@ async def test_message_vector_search_with_real_tpuf(enable_message_embedding, de
|
||||
# Clean up namespace
|
||||
try:
|
||||
await client.delete_all_messages(agent_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1268,7 +1262,7 @@ async def test_message_hybrid_search_with_real_tpuf(enable_message_embedding, de
|
||||
# Clean up namespace
|
||||
try:
|
||||
await client.delete_all_messages(agent_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1340,7 +1334,7 @@ async def test_message_role_filtering_with_real_tpuf(enable_message_embedding, d
|
||||
# Clean up namespace
|
||||
try:
|
||||
await client.delete_all_messages(agent_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1357,7 +1351,7 @@ async def test_message_search_fallback_to_sql(server, default_user, sarah_agent)
|
||||
settings.embed_all_messages = False
|
||||
|
||||
# Create messages
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
@@ -1398,7 +1392,7 @@ async def test_message_update_reindexes_in_turbopuffer(server, default_user, sar
|
||||
"""Test that updating a message properly deletes and re-inserts with new embedding in Turbopuffer"""
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create initial message
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
@@ -1493,8 +1487,6 @@ async def test_message_deletion_syncs_with_turbopuffer(server, default_user, ena
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
embedding_config = agent_a.embedding_config
|
||||
|
||||
try:
|
||||
# Create 5 messages for agent A
|
||||
agent_a_messages = []
|
||||
@@ -1597,7 +1589,7 @@ async def test_turbopuffer_failure_does_not_break_postgres(server, default_user,
|
||||
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create initial messages
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
@@ -1668,7 +1660,7 @@ async def test_turbopuffer_failure_does_not_break_postgres(server, default_user,
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_creation_background_mode(server, default_user, sarah_agent, enable_message_embedding):
|
||||
"""Test that messages are embedded in background when strict_mode=False"""
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create message in background mode
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
@@ -1723,7 +1715,7 @@ async def test_message_update_background_mode(server, default_user, sarah_agent,
|
||||
"""Test that message updates work in background mode"""
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create initial message with strict_mode=True to ensure it's embedded
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
@@ -1899,7 +1891,7 @@ async def test_message_date_filtering_with_real_tpuf(enable_message_embedding, d
|
||||
# Clean up namespace
|
||||
try:
|
||||
await client.delete_all_messages(agent_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -2403,7 +2395,7 @@ async def test_query_messages_with_mixed_conversation_id_presence(enable_message
|
||||
async with AsyncTurbopuffer(api_key=client.api_key, region=client.region) as tpuf:
|
||||
namespace = tpuf.namespace(namespace_name)
|
||||
await namespace.delete_all()
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -2485,7 +2477,7 @@ async def test_query_messages_by_org_id_with_missing_conversation_id_schema(enab
|
||||
async with AsyncTurbopuffer(api_key=client.api_key, region=client.region) as tpuf:
|
||||
namespace = tpuf.namespace(namespace_name)
|
||||
await namespace.delete_all()
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -2541,5 +2533,5 @@ async def test_system_messages_not_embedded_during_agent_creation(server, defaul
|
||||
# Clean up
|
||||
try:
|
||||
await server.agent_manager.delete_agent_async(agent.id, default_user)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -400,7 +400,7 @@ async def test_run_level_usage_aggregation(
|
||||
|
||||
try:
|
||||
# Send multiple messages to create multiple steps
|
||||
response1: Run = await async_client.agents.messages.send_message(
|
||||
await async_client.agents.messages.send_message(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreateParam(role="user", content="Message 1")],
|
||||
)
|
||||
|
||||
@@ -497,7 +497,7 @@ async def sandbox_env_var_fixture(server: SyncServer, sandbox_config_fixture, de
|
||||
@pytest.fixture
|
||||
async def file_attachment(server: SyncServer, default_user, sarah_agent, default_file):
|
||||
"""Create a file attachment to an agent."""
|
||||
assoc, closed_files = await server.file_agent_manager.attach_file(
|
||||
assoc, _closed_files = await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
|
||||
@@ -279,7 +279,7 @@ async def test_compaction_settings_model_uses_separate_llm_config_for_summarizat
|
||||
)
|
||||
|
||||
# Minimal message buffer: system + one user + one assistant
|
||||
messages = [
|
||||
[
|
||||
PydanticMessage(
|
||||
role=MessageRole.system,
|
||||
content=[TextContent(type="text", text="You are a helpful assistant.")],
|
||||
@@ -500,10 +500,10 @@ async def test_get_context_window_basic(
|
||||
server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, set_letta_environment
|
||||
):
|
||||
# Test agent creation
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
created_agent, _create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Attach a file
|
||||
assoc, closed_files = await server.file_agent_manager.attach_file(
|
||||
assoc, _closed_files = await server.file_agent_manager.attach_file(
|
||||
agent_id=created_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
@@ -879,7 +879,7 @@ async def test_update_agent_last_stop_reason(server: SyncServer, comprehensive_t
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
# Create an agent using the comprehensive fixture.
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
_created_agent, _create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# List agents using an empty list for select_fields.
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=[])
|
||||
@@ -897,7 +897,7 @@ async def test_list_agents_select_fields_empty(server: SyncServer, comprehensive
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
# Create an agent using the comprehensive fixture.
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
_created_agent, _create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# List agents using an empty list for select_fields.
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=None)
|
||||
@@ -914,7 +914,7 @@ async def test_list_agents_select_fields_none(server: SyncServer, comprehensive_
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
_created_agent, _create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Choose a subset of valid relationship fields.
|
||||
valid_fields = ["tools", "tags"]
|
||||
@@ -931,7 +931,7 @@ async def test_list_agents_select_fields_specific(server: SyncServer, comprehens
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
_created_agent, _create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Provide field names that are not recognized.
|
||||
invalid_fields = ["foobar", "nonexistent_field"]
|
||||
@@ -946,7 +946,7 @@ async def test_list_agents_select_fields_invalid(server: SyncServer, comprehensi
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
_created_agent, _create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Provide duplicate valid field names.
|
||||
duplicate_fields = ["tools", "tools", "tags", "tags"]
|
||||
@@ -961,7 +961,7 @@ async def test_list_agents_select_fields_duplicates(server: SyncServer, comprehe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
_created_agent, _create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Mix valid fields with an invalid one.
|
||||
mixed_fields = ["tools", "invalid_field"]
|
||||
@@ -978,7 +978,7 @@ async def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
# Create two agents with known names
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_oldest",
|
||||
agent_type="memgpt_v2_agent",
|
||||
@@ -993,7 +993,7 @@ async def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_newest",
|
||||
agent_type="memgpt_v2_agent",
|
||||
@@ -1013,7 +1013,7 @@ async def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_descending(server: SyncServer, default_user):
|
||||
# Create two agents with known names
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_oldest",
|
||||
agent_type="memgpt_v2_agent",
|
||||
@@ -1028,7 +1028,7 @@ async def test_list_agents_descending(server: SyncServer, default_user):
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_newest",
|
||||
agent_type="memgpt_v2_agent",
|
||||
@@ -1084,7 +1084,7 @@ async def test_list_agents_by_last_stop_reason(server: SyncServer, default_user)
|
||||
)
|
||||
|
||||
# Create agent with no stop reason
|
||||
agent3 = await server.agent_manager.create_agent_async(
|
||||
await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_no_stop_reason",
|
||||
agent_type="memgpt_v2_agent",
|
||||
@@ -1172,7 +1172,7 @@ async def test_count_agents_with_filters(server: SyncServer, default_user):
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent4 = await server.agent_manager.create_agent_async(
|
||||
await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_no_stop_reason",
|
||||
agent_type="memgpt_v2_agent",
|
||||
@@ -1963,14 +1963,14 @@ async def test_create_template_agent_with_files_from_sources(server: SyncServer,
|
||||
organization_id=default_user.organization_id,
|
||||
source_id=source.id,
|
||||
)
|
||||
file1 = await server.file_manager.create_file(file_metadata=file1_metadata, actor=default_user, text="content for file 1")
|
||||
await server.file_manager.create_file(file_metadata=file1_metadata, actor=default_user, text="content for file 1")
|
||||
|
||||
file2_metadata = PydanticFileMetadata(
|
||||
file_name="template_file_2.txt",
|
||||
organization_id=default_user.organization_id,
|
||||
source_id=source.id,
|
||||
)
|
||||
file2 = await server.file_manager.create_file(file_metadata=file2_metadata, actor=default_user, text="content for file 2")
|
||||
await server.file_manager.create_file(file_metadata=file2_metadata, actor=default_user, text="content for file 2")
|
||||
|
||||
# Create agent using InternalTemplateAgentCreate with the source
|
||||
create_agent_request = InternalTemplateAgentCreate(
|
||||
|
||||
@@ -320,7 +320,7 @@ class TestMessageStateDesyncIssues:
|
||||
print(f" background={request.background}")
|
||||
|
||||
# Start the background streaming agent
|
||||
run, stream_response = await streaming_service.create_agent_stream(
|
||||
run, _stream_response = await streaming_service.create_agent_stream(
|
||||
agent_id=test_agent_with_tool.id,
|
||||
actor=default_user,
|
||||
request=request,
|
||||
@@ -510,7 +510,7 @@ class TestStreamingCancellation:
|
||||
try:
|
||||
async for chunk in cancel_during_stream():
|
||||
chunks.append(chunk)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# May raise exception on cancellation
|
||||
pass
|
||||
|
||||
@@ -733,7 +733,7 @@ class TestResourceCleanupAfterCancellation:
|
||||
|
||||
input_messages = [MessageCreate(role=MessageRole.user, content="Call print_tool with 'test'")]
|
||||
|
||||
result = await agent_loop.step(
|
||||
await agent_loop.step(
|
||||
input_messages=input_messages,
|
||||
max_steps=5,
|
||||
run_id=test_run.id,
|
||||
@@ -895,7 +895,7 @@ class TestApprovalFlowCancellation:
|
||||
)
|
||||
|
||||
# Check for approval request messages
|
||||
approval_messages = [m for m in messages_after_cancel if m.role == "approval_request"]
|
||||
[m for m in messages_after_cancel if m.role == "approval_request"]
|
||||
|
||||
# Second run: try to execute normally (should work, not stuck in approval)
|
||||
test_run_2 = await server.run_manager.create_run(
|
||||
@@ -1075,7 +1075,7 @@ class TestApprovalFlowCancellation:
|
||||
assert result.stop_reason.stop_reason == "requires_approval", f"Expected requires_approval, got {result.stop_reason.stop_reason}"
|
||||
|
||||
# Get all messages from database for this run
|
||||
db_messages = await server.message_manager.list_messages(
|
||||
await server.message_manager.list_messages(
|
||||
actor=default_user,
|
||||
agent_id=test_agent_with_tool.id,
|
||||
run_id=test_run.id,
|
||||
@@ -1210,7 +1210,7 @@ class TestApprovalFlowCancellation:
|
||||
assert result.stop_reason.stop_reason == "requires_approval", f"Should stop for approval, got {result.stop_reason.stop_reason}"
|
||||
|
||||
# Get the approval request message to see how many tool calls were made
|
||||
db_messages_before_cancel = await server.message_manager.list_messages(
|
||||
await server.message_manager.list_messages(
|
||||
actor=default_user,
|
||||
agent_id=agent_state.id,
|
||||
run_id=test_run.id,
|
||||
|
||||
@@ -18,7 +18,7 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_creates_association(server, default_user, sarah_agent, default_file):
|
||||
assoc, closed_files = await server.file_agent_manager.attach_file(
|
||||
assoc, _closed_files = await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
@@ -40,7 +40,7 @@ async def test_attach_creates_association(server, default_user, sarah_agent, def
|
||||
|
||||
|
||||
async def test_attach_is_idempotent(server, default_user, sarah_agent, default_file):
|
||||
a1, closed_files = await server.file_agent_manager.attach_file(
|
||||
a1, _closed_files = await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
@@ -51,7 +51,7 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f
|
||||
)
|
||||
|
||||
# second attach with different params
|
||||
a2, closed_files = await server.file_agent_manager.attach_file(
|
||||
a2, _closed_files = await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
file_name=default_file.file_name,
|
||||
@@ -114,7 +114,7 @@ async def test_file_agent_line_tracking(server, default_user, sarah_agent, defau
|
||||
file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user, text=test_content)
|
||||
|
||||
# Test opening with line range using enforce_max_open_files_and_open
|
||||
closed_files, was_already_open, previous_ranges = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
_closed_files, _was_already_open, previous_ranges = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
@@ -138,7 +138,7 @@ async def test_file_agent_line_tracking(server, default_user, sarah_agent, defau
|
||||
assert previous_ranges == {} # No previous range since it wasn't open before
|
||||
|
||||
# Test opening without line range - should clear line info and capture previous range
|
||||
closed_files, was_already_open, previous_ranges = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
_closed_files, _was_already_open, previous_ranges = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
@@ -321,7 +321,7 @@ async def test_list_files_for_agent_paginated_filter_open(
|
||||
)
|
||||
|
||||
# get only open files
|
||||
open_files, cursor, has_more = await server.file_agent_manager.list_files_for_agent_paginated(
|
||||
open_files, _cursor, has_more = await server.file_agent_manager.list_files_for_agent_paginated(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
is_open=True,
|
||||
@@ -370,7 +370,7 @@ async def test_list_files_for_agent_paginated_filter_closed(
|
||||
assert all(not fa.is_open for fa in page1)
|
||||
|
||||
# get second page of closed files
|
||||
page2, cursor2, has_more2 = await server.file_agent_manager.list_files_for_agent_paginated(
|
||||
page2, _cursor2, has_more2 = await server.file_agent_manager.list_files_for_agent_paginated(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
is_open=False,
|
||||
@@ -586,7 +586,7 @@ async def test_mark_access_bulk(server, default_user, sarah_agent, default_sourc
|
||||
# Attach all files (they'll be open by default)
|
||||
attached_files = []
|
||||
for file in files:
|
||||
file_agent, closed_files = await server.file_agent_manager.attach_file(
|
||||
file_agent, _closed_files = await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
@@ -745,7 +745,7 @@ async def test_lru_eviction_on_open_file(server, default_user, sarah_agent, defa
|
||||
time.sleep(0.1)
|
||||
|
||||
# Now "open" the last file using the efficient method
|
||||
closed_files, was_already_open, _ = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
closed_files, _was_already_open, _ = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=files[-1].id,
|
||||
file_name=files[-1].file_name,
|
||||
@@ -853,7 +853,7 @@ async def test_last_accessed_at_updates_correctly(server, default_user, sarah_ag
|
||||
)
|
||||
file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user, text="test content")
|
||||
|
||||
file_agent, closed_files = await server.file_agent_manager.attach_file(
|
||||
file_agent, _closed_files = await server.file_agent_manager.attach_file(
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=file.id,
|
||||
file_name=file.file_name,
|
||||
@@ -957,7 +957,7 @@ async def test_attach_files_bulk_deduplication(server, default_user, sarah_agent
|
||||
visible_content_map = {"duplicate_test.txt": "visible content"}
|
||||
|
||||
# Bulk attach should deduplicate
|
||||
closed_files = await server.file_agent_manager.attach_files_bulk(
|
||||
await server.file_agent_manager.attach_files_bulk(
|
||||
agent_id=sarah_agent.id,
|
||||
files_metadata=files_to_attach,
|
||||
visible_content_map=visible_content_map,
|
||||
@@ -1085,7 +1085,7 @@ async def test_attach_files_bulk_mixed_existing_new(server, default_user, sarah_
|
||||
new_files.append(file)
|
||||
|
||||
# Bulk attach: existing file + new files
|
||||
files_to_attach = [existing_file] + new_files
|
||||
files_to_attach = [existing_file, *new_files]
|
||||
visible_content_map = {
|
||||
"existing_file.txt": "updated content",
|
||||
"new_file_0.txt": "new content 0",
|
||||
|
||||
@@ -306,31 +306,3 @@ async def test_get_set_blocks_for_identities(server: SyncServer, default_block,
|
||||
assert block_without_identity.id not in block_ids
|
||||
|
||||
await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user)
|
||||
|
||||
|
||||
async def test_upsert_properties(server: SyncServer, default_user):
|
||||
identity_create = IdentityCreate(
|
||||
identifier_key="1234",
|
||||
name="caren",
|
||||
identity_type=IdentityType.user,
|
||||
properties=[
|
||||
IdentityProperty(key="email", value="caren@letta.com", type=IdentityPropertyType.string),
|
||||
IdentityProperty(key="age", value=28, type=IdentityPropertyType.number),
|
||||
],
|
||||
)
|
||||
|
||||
identity = await server.identity_manager.create_identity_async(identity_create, actor=default_user)
|
||||
properties = [
|
||||
IdentityProperty(key="email", value="caren@gmail.com", type=IdentityPropertyType.string),
|
||||
IdentityProperty(key="age", value="28", type=IdentityPropertyType.string),
|
||||
IdentityProperty(key="test", value=123, type=IdentityPropertyType.number),
|
||||
]
|
||||
|
||||
updated_identity = await server.identity_manager.upsert_identity_properties_async(
|
||||
identity_id=identity.id,
|
||||
properties=properties,
|
||||
actor=default_user,
|
||||
)
|
||||
assert updated_identity.properties == properties
|
||||
|
||||
await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user)
|
||||
|
||||
@@ -71,7 +71,7 @@ async def test_create_mcp_server(mock_get_client, server, default_user):
|
||||
# Test with a valid SSEServerConfig
|
||||
mcp_server_name = "coingecko"
|
||||
server_url = "https://mcp.api.coingecko.com/sse"
|
||||
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
||||
SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
||||
mcp_sse_server = MCPServer(server_name=mcp_server_name, server_type=MCPServerType.SSE, server_url=server_url)
|
||||
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_sse_server, actor=default_user)
|
||||
print(created_server)
|
||||
@@ -797,7 +797,7 @@ async def test_mcp_server_resync_tools(server, default_user, default_organizatio
|
||||
|
||||
# Verify tool2 was actually deleted
|
||||
try:
|
||||
deleted_tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool2.id, actor=default_user)
|
||||
await server.tool_manager.get_tool_by_id_async(tool_id=tool2.id, actor=default_user)
|
||||
assert False, "Tool2 should have been deleted"
|
||||
except Exception:
|
||||
pass # Expected - tool should be deleted
|
||||
|
||||
@@ -214,10 +214,10 @@ async def test_modify_letta_message(server: SyncServer, sarah_agent, default_use
|
||||
messages = await server.message_manager.list_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
letta_messages = PydanticMessage.to_letta_messages_from_list(messages=messages)
|
||||
|
||||
system_message = [msg for msg in letta_messages if msg.message_type == "system_message"][0]
|
||||
assistant_message = [msg for msg in letta_messages if msg.message_type == "assistant_message"][0]
|
||||
user_message = [msg for msg in letta_messages if msg.message_type == "user_message"][0]
|
||||
reasoning_message = [msg for msg in letta_messages if msg.message_type == "reasoning_message"][0]
|
||||
system_message = next(msg for msg in letta_messages if msg.message_type == "system_message")
|
||||
assistant_message = next(msg for msg in letta_messages if msg.message_type == "assistant_message")
|
||||
user_message = next(msg for msg in letta_messages if msg.message_type == "user_message")
|
||||
reasoning_message = next(msg for msg in letta_messages if msg.message_type == "reasoning_message")
|
||||
|
||||
# user message
|
||||
update_user_message = UpdateUserMessage(content="Hello, Sarah!")
|
||||
|
||||
@@ -908,7 +908,7 @@ async def test_server_startup_handles_api_errors_gracefully(default_user, defaul
|
||||
actor=default_user,
|
||||
)
|
||||
if len(openai_providers) > 0:
|
||||
openai_models = await server.provider_manager.list_models_async(
|
||||
await server.provider_manager.list_models_async(
|
||||
actor=default_user,
|
||||
provider_id=openai_providers[0].id,
|
||||
)
|
||||
|
||||
@@ -161,8 +161,7 @@ async def test_update_run_updates_agent_last_stop_reason(server: SyncServer, sar
|
||||
"""Test that completing a run updates the agent's last_stop_reason."""
|
||||
|
||||
# Verify agent starts with no last_stop_reason
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
initial_stop_reason = agent.last_stop_reason
|
||||
await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
|
||||
# Create a run
|
||||
run_data = PydanticRun(agent_id=sarah_agent.id)
|
||||
@@ -867,7 +866,7 @@ async def test_run_messages_ordering(server: SyncServer, default_run, default_us
|
||||
created_at=created_at,
|
||||
run_id=run.id,
|
||||
)
|
||||
msg = await server.message_manager.create_many_messages_async([message], actor=default_user)
|
||||
await server.message_manager.create_many_messages_async([message], actor=default_user)
|
||||
|
||||
# Verify messages are returned in chronological order
|
||||
returned_messages = await server.message_manager.list_messages(
|
||||
@@ -1015,7 +1014,7 @@ async def test_get_run_messages(server: SyncServer, default_user: PydanticUser,
|
||||
)
|
||||
)
|
||||
|
||||
created_msg = await server.message_manager.create_many_messages_async(messages, actor=default_user)
|
||||
await server.message_manager.create_many_messages_async(messages, actor=default_user)
|
||||
|
||||
# Get messages and verify they're converted correctly
|
||||
result = await server.message_manager.list_messages(run_id=run.id, actor=default_user)
|
||||
@@ -1088,7 +1087,7 @@ async def test_get_run_messages_with_assistant_message(server: SyncServer, defau
|
||||
)
|
||||
)
|
||||
|
||||
created_msg = await server.message_manager.create_many_messages_async(messages, actor=default_user)
|
||||
await server.message_manager.create_many_messages_async(messages, actor=default_user)
|
||||
|
||||
# Get messages and verify they're converted correctly
|
||||
result = await server.message_manager.list_messages(run_id=run.id, actor=default_user)
|
||||
@@ -1369,7 +1368,7 @@ async def test_run_metrics_duration_calculation(server: SyncServer, sarah_agent,
|
||||
await asyncio.sleep(0.1) # Wait 100ms
|
||||
|
||||
# Update the run to completed
|
||||
updated_run = await server.run_manager.update_run_by_id_async(
|
||||
await server.run_manager.update_run_by_id_async(
|
||||
created_run.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user
|
||||
)
|
||||
|
||||
@@ -1663,7 +1662,7 @@ def test_convert_statuses_to_enum_with_invalid_status():
|
||||
async def test_list_runs_with_multiple_statuses(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing runs with multiple status filters."""
|
||||
# Create runs with different statuses
|
||||
run_created = await server.run_manager.create_run(
|
||||
await server.run_manager.create_run(
|
||||
pydantic_run=PydanticRun(
|
||||
status=RunStatus.created,
|
||||
agent_id=sarah_agent.id,
|
||||
@@ -1671,7 +1670,7 @@ async def test_list_runs_with_multiple_statuses(server: SyncServer, sarah_agent,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
run_running = await server.run_manager.create_run(
|
||||
await server.run_manager.create_run(
|
||||
pydantic_run=PydanticRun(
|
||||
status=RunStatus.running,
|
||||
agent_id=sarah_agent.id,
|
||||
@@ -1679,7 +1678,7 @@ async def test_list_runs_with_multiple_statuses(server: SyncServer, sarah_agent,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
run_completed = await server.run_manager.create_run(
|
||||
await server.run_manager.create_run(
|
||||
pydantic_run=PydanticRun(
|
||||
status=RunStatus.completed,
|
||||
agent_id=sarah_agent.id,
|
||||
@@ -1687,7 +1686,7 @@ async def test_list_runs_with_multiple_statuses(server: SyncServer, sarah_agent,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
run_failed = await server.run_manager.create_run(
|
||||
await server.run_manager.create_run(
|
||||
pydantic_run=PydanticRun(
|
||||
status=RunStatus.failed,
|
||||
agent_id=sarah_agent.id,
|
||||
|
||||
@@ -387,7 +387,7 @@ async def test_create_sources_with_same_name_raises_error(server: SyncServer, de
|
||||
metadata={"type": "medical"},
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
)
|
||||
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
await server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
|
||||
# Attempting to create another source with the same name should raise an IntegrityError
|
||||
source_pydantic = PydanticSource(
|
||||
@@ -1120,7 +1120,7 @@ async def test_file_status_invalid_transitions(server, default_user, default_sou
|
||||
)
|
||||
created = await server.file_manager.create_file(file_metadata=meta, actor=default_user)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid state transition.*pending.*COMPLETED"):
|
||||
with pytest.raises(ValueError, match=r"Invalid state transition.*pending.*COMPLETED"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
@@ -1142,7 +1142,7 @@ async def test_file_status_invalid_transitions(server, default_user, default_sou
|
||||
processing_status=FileProcessingStatus.PARSING,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid state transition.*parsing.*COMPLETED"):
|
||||
with pytest.raises(ValueError, match=r"Invalid state transition.*parsing.*COMPLETED"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created2.id,
|
||||
actor=default_user,
|
||||
@@ -1159,7 +1159,7 @@ async def test_file_status_invalid_transitions(server, default_user, default_sou
|
||||
)
|
||||
created3 = await server.file_manager.create_file(file_metadata=meta3, actor=default_user)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid state transition.*pending.*EMBEDDING"):
|
||||
with pytest.raises(ValueError, match=r"Invalid state transition.*pending.*EMBEDDING"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created3.id,
|
||||
actor=default_user,
|
||||
@@ -1186,14 +1186,14 @@ async def test_file_status_terminal_states(server, default_user, default_source)
|
||||
await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.COMPLETED)
|
||||
|
||||
# Cannot transition from COMPLETED to any state
|
||||
with pytest.raises(ValueError, match="Cannot update.*terminal state completed"):
|
||||
with pytest.raises(ValueError, match=r"Cannot update.*terminal state completed"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
processing_status=FileProcessingStatus.EMBEDDING,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot update.*terminal state completed"):
|
||||
with pytest.raises(ValueError, match=r"Cannot update.*terminal state completed"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
@@ -1219,7 +1219,7 @@ async def test_file_status_terminal_states(server, default_user, default_source)
|
||||
)
|
||||
|
||||
# Cannot transition from ERROR to any state
|
||||
with pytest.raises(ValueError, match="Cannot update.*terminal state error"):
|
||||
with pytest.raises(ValueError, match=r"Cannot update.*terminal state error"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created2.id,
|
||||
actor=default_user,
|
||||
@@ -1313,7 +1313,7 @@ async def test_file_status_terminal_state_non_status_updates(server, default_use
|
||||
await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.COMPLETED)
|
||||
|
||||
# Cannot update chunks_embedded in COMPLETED state
|
||||
with pytest.raises(ValueError, match="Cannot update.*terminal state completed"):
|
||||
with pytest.raises(ValueError, match=r"Cannot update.*terminal state completed"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
@@ -1321,7 +1321,7 @@ async def test_file_status_terminal_state_non_status_updates(server, default_use
|
||||
)
|
||||
|
||||
# Cannot update total_chunks in COMPLETED state
|
||||
with pytest.raises(ValueError, match="Cannot update.*terminal state completed"):
|
||||
with pytest.raises(ValueError, match=r"Cannot update.*terminal state completed"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
@@ -1329,7 +1329,7 @@ async def test_file_status_terminal_state_non_status_updates(server, default_use
|
||||
)
|
||||
|
||||
# Cannot update error_message in COMPLETED state
|
||||
with pytest.raises(ValueError, match="Cannot update.*terminal state completed"):
|
||||
with pytest.raises(ValueError, match=r"Cannot update.*terminal state completed"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
@@ -1353,7 +1353,7 @@ async def test_file_status_terminal_state_non_status_updates(server, default_use
|
||||
)
|
||||
|
||||
# Cannot update chunks_embedded in ERROR state
|
||||
with pytest.raises(ValueError, match="Cannot update.*terminal state error"):
|
||||
with pytest.raises(ValueError, match=r"Cannot update.*terminal state error"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created2.id,
|
||||
actor=default_user,
|
||||
@@ -1399,7 +1399,7 @@ async def test_file_status_race_condition_prevention(server, default_user, defau
|
||||
|
||||
# Try to continue with EMBEDDING as if error didn't happen (race condition)
|
||||
# This should fail because file is in ERROR state
|
||||
with pytest.raises(ValueError, match="Cannot update.*terminal state error"):
|
||||
with pytest.raises(ValueError, match=r"Cannot update.*terminal state error"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
@@ -1424,7 +1424,7 @@ async def test_file_status_backwards_transitions(server, default_user, default_s
|
||||
await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.EMBEDDING)
|
||||
|
||||
# Cannot go back to PARSING
|
||||
with pytest.raises(ValueError, match="Invalid state transition.*embedding.*PARSING"):
|
||||
with pytest.raises(ValueError, match=r"Invalid state transition.*embedding.*PARSING"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
@@ -1432,7 +1432,7 @@ async def test_file_status_backwards_transitions(server, default_user, default_s
|
||||
)
|
||||
|
||||
# Cannot go back to PENDING
|
||||
with pytest.raises(ValueError, match="Cannot transition to PENDING state.*PENDING is only valid as initial state"):
|
||||
with pytest.raises(ValueError, match=r"Cannot transition to PENDING state.*PENDING is only valid as initial state"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
|
||||
@@ -1945,8 +1945,8 @@ def test_function():
|
||||
source_code=source_code,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
created_tool = await tool_manager.create_or_update_tool_async(tool, default_user)
|
||||
with pytest.raises(ValueError):
|
||||
await tool_manager.create_or_update_tool_async(tool, default_user)
|
||||
|
||||
|
||||
async def test_error_on_create_tool_with_name_conflict(server: SyncServer, default_user, default_organization):
|
||||
|
||||
@@ -17,7 +17,7 @@ from letta.server.server import SyncServer
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def truncate_database():
|
||||
from letta.server.db import db_context
|
||||
from letta.server.db import db_context # type: ignore[attr-defined]
|
||||
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
|
||||
@@ -186,7 +186,7 @@ class TestSchemaValidator:
|
||||
}
|
||||
|
||||
# This should actually be STRICT_COMPLIANT since empty arrays with defined items are OK
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
status, _reasons = validate_complete_json_schema(schema)
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
|
||||
def test_array_without_constraints_invalid(self):
|
||||
|
||||
@@ -111,7 +111,7 @@ async def test_insert_archival_memories_concurrent(client):
|
||||
cdf_y = np.arange(1, len(durs_sorted) + 1) / len(durs_sorted)
|
||||
|
||||
# Plot all 6 subplots
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||
_fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||
axs = axes.ravel()
|
||||
|
||||
# 1) Kickoff timeline
|
||||
|
||||
@@ -187,7 +187,7 @@ def get_attr(obj, attr):
|
||||
return getattr(obj, attr, None)
|
||||
|
||||
|
||||
def create_stdio_server_request(server_name: str, command: str = "npx", args: List[str] = None) -> Dict[str, Any]:
|
||||
def create_stdio_server_request(server_name: str, command: str = "npx", args: List[str] | None = None) -> Dict[str, Any]:
|
||||
"""Create a stdio MCP server configuration object.
|
||||
|
||||
Returns a dict with server_name and config following CreateMCPServerRequest schema.
|
||||
@@ -203,7 +203,7 @@ def create_stdio_server_request(server_name: str, command: str = "npx", args: Li
|
||||
}
|
||||
|
||||
|
||||
def create_sse_server_request(server_name: str, server_url: str = None) -> Dict[str, Any]:
|
||||
def create_sse_server_request(server_name: str, server_url: str | None = None) -> Dict[str, Any]:
|
||||
"""Create an SSE MCP server configuration object.
|
||||
|
||||
Returns a dict with server_name and config following CreateMCPServerRequest schema.
|
||||
@@ -220,7 +220,7 @@ def create_sse_server_request(server_name: str, server_url: str = None) -> Dict[
|
||||
}
|
||||
|
||||
|
||||
def create_streamable_http_server_request(server_name: str, server_url: str = None) -> Dict[str, Any]:
|
||||
def create_streamable_http_server_request(server_name: str, server_url: str | None = None) -> Dict[str, Any]:
|
||||
"""Create a streamable HTTP MCP server configuration object.
|
||||
|
||||
Returns a dict with server_name and config following CreateMCPServerRequest schema.
|
||||
@@ -508,7 +508,7 @@ def test_invalid_server_type(client: Letta):
|
||||
client.mcp_servers.create(**invalid_config)
|
||||
# If we get here without an exception, the test should fail
|
||||
assert False, "Expected an error when creating server with missing required fields"
|
||||
except (BadRequestError, UnprocessableEntityError, TypeError, ValueError) as e:
|
||||
except (BadRequestError, UnprocessableEntityError, TypeError, ValueError):
|
||||
# Expected to fail - this is good
|
||||
test_passed = True
|
||||
|
||||
|
||||
@@ -220,7 +220,7 @@ def test_passage_search_basic(client: Letta, enable_turbopuffer):
|
||||
# Clean up archive
|
||||
try:
|
||||
client.archives.delete(archive_id=archive.id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
@@ -282,7 +282,7 @@ def test_passage_search_with_tags(client: Letta, enable_turbopuffer):
|
||||
# Clean up archive
|
||||
try:
|
||||
client.archives.delete(archive_id=archive.id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
@@ -350,7 +350,7 @@ def test_passage_search_with_date_filters(client: Letta, enable_turbopuffer):
|
||||
# Clean up archive
|
||||
try:
|
||||
client.archives.delete(archive_id=archive.id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
@@ -489,7 +489,7 @@ def test_passage_search_pagination(client: Letta, enable_turbopuffer):
|
||||
# Clean up archive
|
||||
try:
|
||||
client.archives.delete(archive_id=archive.id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
@@ -554,11 +554,11 @@ def test_passage_search_org_wide(client: Letta, enable_turbopuffer):
|
||||
# Clean up archives
|
||||
try:
|
||||
client.archives.delete(archive_id=archive1.id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
client.archives.delete(archive_id=archive2.id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
|
||||
@@ -355,7 +355,7 @@ def compare_in_context_message_id_remapping(server, og_agent: AgentState, copy_a
|
||||
remapped IDs but identical relevant content and order.
|
||||
"""
|
||||
# Serialize the original agent state
|
||||
result = server.agent_manager.serialize(agent_id=og_agent.id, actor=og_user)
|
||||
server.agent_manager.serialize(agent_id=og_agent.id, actor=og_user)
|
||||
|
||||
# Retrieve the in-context messages for both the original and the copy
|
||||
# Corrected typo: agent_id instead of agent_id
|
||||
|
||||
@@ -774,7 +774,7 @@ class TestFileExport:
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_file_export(self, default_user, agent_serialization_manager, agent_with_files):
|
||||
"""Test basic file export functionality"""
|
||||
agent_id, source_id, file_id = agent_with_files
|
||||
agent_id, _source_id, _file_id = agent_with_files
|
||||
|
||||
exported = await agent_serialization_manager.export([agent_id], actor=default_user)
|
||||
|
||||
@@ -925,7 +925,7 @@ class TestFileExport:
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_content_inclusion_in_export(self, default_user, agent_serialization_manager, agent_with_files):
|
||||
"""Test that file content is included in export"""
|
||||
agent_id, source_id, file_id = agent_with_files
|
||||
agent_id, _source_id, _file_id = agent_with_files
|
||||
|
||||
exported = await agent_serialization_manager.export([agent_id], actor=default_user)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def mock_openai_server():
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
|
||||
def do_GET(self): # noqa: N802
|
||||
def do_GET(self):
|
||||
# Support OpenAI model listing used during provider sync.
|
||||
if self.path in ("/v1/models", "/models"):
|
||||
self._send_json(
|
||||
@@ -78,7 +78,7 @@ def mock_openai_server():
|
||||
|
||||
self._send_json(404, {"error": {"message": f"Not found: {self.path}"}})
|
||||
|
||||
def do_POST(self): # noqa: N802
|
||||
def do_POST(self):
|
||||
# Support embeddings endpoint
|
||||
if self.path not in ("/v1/embeddings", "/embeddings"):
|
||||
self._send_json(404, {"error": {"message": f"Not found: {self.path}"}})
|
||||
@@ -739,7 +739,7 @@ def test_initial_sequence(client: Letta):
|
||||
|
||||
# list messages
|
||||
messages = client.agents.messages.list(agent_id=agent.id).items
|
||||
response = client.agents.messages.create(
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
MessageCreateParam(
|
||||
@@ -803,7 +803,7 @@ def test_attach_sleeptime_block(client: Letta):
|
||||
group_id = agent.multi_agent_group.id
|
||||
group = client.groups.retrieve(group_id=group_id)
|
||||
agent_ids = group.agent_ids
|
||||
sleeptime_id = [id for id in agent_ids if id != agent.id][0]
|
||||
sleeptime_id = next(id for id in agent_ids if id != agent.id)
|
||||
|
||||
# attach a new block
|
||||
block = client.blocks.create(label="test", value="test") # , project_id="test")
|
||||
@@ -891,7 +891,6 @@ def test_agent_generate_with_system_prompt(client: Letta, agent: AgentState):
|
||||
def test_agent_generate_with_model_override(client: Letta, agent: AgentState):
|
||||
"""Test generate endpoint with model override."""
|
||||
# Get the agent's current model
|
||||
original_model = agent.llm_config.model
|
||||
|
||||
# Use OpenAI model (more likely to be available in test environment)
|
||||
override_model_handle = "openai/gpt-4o-mini"
|
||||
|
||||
@@ -2,14 +2,13 @@ import httpx
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from letta.embeddings import GoogleEmbeddings # Adjust the import based on your module structure
|
||||
from letta.embeddings import GoogleEmbeddings # type: ignore[import-untyped] # Adjust the import based on your module structure
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from letta_client import CreateBlock, Letta as LettaSDKClient, MessageCreate
|
||||
|
||||
SERVER_PORT = 8283
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_agents(client: Letta) -> List[AgentState]:
|
||||
for agent in agents:
|
||||
try:
|
||||
client.agents.delete(agent.id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import itertools
|
||||
from datetime import datetime, timezone
|
||||
from typing import Tuple
|
||||
from unittest.mock import AsyncMock, patch
|
||||
@@ -769,7 +770,7 @@ def _assert_descending_order(messages):
|
||||
if len(messages) <= 1:
|
||||
return True
|
||||
|
||||
for prev, next in zip(messages[:-1], messages[1:]):
|
||||
for prev, next in itertools.pairwise(messages):
|
||||
assert prev.created_at >= next.created_at, (
|
||||
f"Order violation: {prev.id} ({prev.created_at}) followed by {next.id} ({next.created_at})"
|
||||
)
|
||||
|
||||
@@ -101,7 +101,7 @@ async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mo
|
||||
a ValueError is raised.
|
||||
"""
|
||||
mismatched_tools = {"agent-2": []} # Different agent ID than in the messages mapping.
|
||||
with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."):
|
||||
with pytest.raises(ValueError, match=r"Agent mappings for messages and tools must use the same agent_ids."):
|
||||
await anthropic_client.send_llm_batch_request_async(
|
||||
AgentType.memgpt_agent, mock_agent_messages, mismatched_tools, mock_agent_llm_config
|
||||
)
|
||||
|
||||
@@ -4,10 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.llm_api.minimax_client import MINIMAX_BASE_URL, MiniMaxClient
|
||||
from letta.llm_api.minimax_client import MiniMaxClient
|
||||
from letta.schemas.enums import AgentType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
# MiniMax API base URL
|
||||
MINIMAX_BASE_URL = "https://api.minimax.io/anthropic"
|
||||
|
||||
|
||||
class TestMiniMaxClient:
|
||||
"""Tests for MiniMaxClient."""
|
||||
@@ -55,7 +58,7 @@ class TestMiniMaxClient:
|
||||
# Mock BYOK to return no override
|
||||
self.client.get_byok_overrides = MagicMock(return_value=(None, None, None))
|
||||
|
||||
client = self.client._get_anthropic_client(self.llm_config, async_client=False)
|
||||
self.client._get_anthropic_client(self.llm_config, async_client=False)
|
||||
|
||||
mock_anthropic.Anthropic.assert_called_once_with(
|
||||
api_key="test-api-key",
|
||||
@@ -73,7 +76,7 @@ class TestMiniMaxClient:
|
||||
# Mock BYOK to return no override
|
||||
self.client.get_byok_overrides = MagicMock(return_value=(None, None, None))
|
||||
|
||||
client = self.client._get_anthropic_client(self.llm_config, async_client=True)
|
||||
self.client._get_anthropic_client(self.llm_config, async_client=True)
|
||||
|
||||
mock_anthropic.AsyncAnthropic.assert_called_once_with(
|
||||
api_key="test-api-key",
|
||||
@@ -100,7 +103,7 @@ class TestMiniMaxClientTemperatureClamping:
|
||||
"""Verify build_request_data is called for temperature clamping."""
|
||||
# This is a basic test to ensure the method exists and can be called
|
||||
mock_build.return_value = {"temperature": 0.7}
|
||||
result = self.client.build_request_data(
|
||||
self.client.build_request_data(
|
||||
agent_type=AgentType.letta_v1_agent,
|
||||
messages=[],
|
||||
llm_config=self.llm_config,
|
||||
@@ -214,7 +217,7 @@ class TestMiniMaxClientUsesNonBetaAPI:
|
||||
mock_anthropic_client.messages.create.return_value = mock_response
|
||||
mock_get_client.return_value = mock_anthropic_client
|
||||
|
||||
result = client.request({"model": "MiniMax-M2.1"}, llm_config)
|
||||
client.request({"model": "MiniMax-M2.1"}, llm_config)
|
||||
|
||||
# Verify messages.create was called (not beta.messages.create)
|
||||
mock_anthropic_client.messages.create.assert_called_once()
|
||||
@@ -239,7 +242,7 @@ class TestMiniMaxClientUsesNonBetaAPI:
|
||||
mock_anthropic_client.messages.create.return_value = mock_response
|
||||
mock_get_client.return_value = mock_anthropic_client
|
||||
|
||||
result = await client.request_async({"model": "MiniMax-M2.1"}, llm_config)
|
||||
await client.request_async({"model": "MiniMax-M2.1"}, llm_config)
|
||||
|
||||
# Verify messages.create was called (not beta.messages.create)
|
||||
mock_anthropic_client.messages.create.assert_called_once()
|
||||
@@ -261,7 +264,7 @@ class TestMiniMaxClientUsesNonBetaAPI:
|
||||
mock_anthropic_client.messages.create.return_value = mock_stream
|
||||
mock_get_client.return_value = mock_anthropic_client
|
||||
|
||||
result = await client.stream_async({"model": "MiniMax-M2.1"}, llm_config)
|
||||
await client.stream_async({"model": "MiniMax-M2.1"}, llm_config)
|
||||
|
||||
# Verify messages.create was called (not beta.messages.create)
|
||||
mock_anthropic_client.messages.create.assert_called_once()
|
||||
|
||||
@@ -542,7 +542,7 @@ async def test_prompt_caching_cache_invalidation_on_memory_update(
|
||||
|
||||
try:
|
||||
# Message 1: Establish cache
|
||||
response1 = await async_client.agents.messages.create(
|
||||
await async_client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreateParam(role="user", content="Hello!")],
|
||||
)
|
||||
|
||||
@@ -23,4 +23,4 @@ async def test_redis_client():
|
||||
assert await redis_client.smismember(k, "invalid") == 0
|
||||
assert await redis_client.smismember(k, v[0]) == 1
|
||||
assert await redis_client.smismember(k, v[:2]) == [1, 1]
|
||||
assert await redis_client.smismember(k, v[2:] + ["invalid"]) == [1, 0]
|
||||
assert await redis_client.smismember(k, [*v[2:], "invalid"]) == [1, 0]
|
||||
|
||||
@@ -6,7 +6,7 @@ import textwrap
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Type
|
||||
from typing import ClassVar, List, Type
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
@@ -379,7 +379,7 @@ def test_add_and_manage_tags_for_agent(client: LettaSDKClient):
|
||||
assert len(agent.tags) == 0
|
||||
|
||||
# Step 1: Add multiple tags to the agent
|
||||
updated_agent = client.agents.update(agent_id=agent.id, tags=tags_to_add)
|
||||
client.agents.update(agent_id=agent.id, tags=tags_to_add)
|
||||
|
||||
# Add small delay to ensure tags are persisted
|
||||
time.sleep(0.1)
|
||||
@@ -397,7 +397,7 @@ def test_add_and_manage_tags_for_agent(client: LettaSDKClient):
|
||||
|
||||
# Step 4: Delete a specific tag from the agent and verify its removal
|
||||
tag_to_delete = tags_to_add.pop()
|
||||
updated_agent = client.agents.update(agent_id=agent.id, tags=tags_to_add)
|
||||
client.agents.update(agent_id=agent.id, tags=tags_to_add)
|
||||
|
||||
# Verify the tag is removed from the agent's tags - explicitly request tags
|
||||
remaining_tags = client.agents.retrieve(agent_id=agent.id, include=["agent.tags"]).tags
|
||||
@@ -426,7 +426,7 @@ def test_reset_messages(client: LettaSDKClient):
|
||||
|
||||
try:
|
||||
# Send a message
|
||||
response = client.agents.messages.create(
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreateParam(role="user", content="Hello")],
|
||||
)
|
||||
@@ -542,7 +542,6 @@ def test_list_files_for_agent(client: LettaSDKClient):
|
||||
raise RuntimeError(f"File {file_metadata.id} not found")
|
||||
if file_metadata.processing_status == "error":
|
||||
raise RuntimeError(f"File processing failed: {getattr(file_metadata, 'error_message', 'Unknown error')}")
|
||||
test_file = file_metadata
|
||||
|
||||
agent = client.agents.create(
|
||||
memory_blocks=[CreateBlockParam(label="persona", value="test")],
|
||||
@@ -604,7 +603,7 @@ def test_modify_message(client: LettaSDKClient):
|
||||
|
||||
try:
|
||||
# Send a message
|
||||
response = client.agents.messages.create(
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreateParam(role="user", content="Original message")],
|
||||
)
|
||||
@@ -987,11 +986,6 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState):
|
||||
|
||||
def test_agent_creation(client: LettaSDKClient):
|
||||
"""Test that block IDs are properly attached when creating an agent."""
|
||||
sleeptime_agent_system = """
|
||||
You are a helpful agent. You will be provided with a list of memory blocks and a user preferences block.
|
||||
You should use the memory blocks to remember information about the user and their preferences.
|
||||
You should also use the user preferences block to remember information about the user's preferences.
|
||||
"""
|
||||
|
||||
# Create a test block that will represent user preferences
|
||||
user_preferences_block = client.blocks.create(
|
||||
@@ -1255,7 +1249,7 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl
|
||||
name: str = "manage_inventory"
|
||||
args_schema: Type[BaseModel] = InventoryEntryData
|
||||
description: str = "Update inventory catalogue with a new data entry"
|
||||
tags: List[str] = ["inventory", "shop"]
|
||||
tags: ClassVar[List[str]] = ["inventory", "shop"]
|
||||
|
||||
def run(self, data: InventoryEntry, quantity_change: int) -> bool:
|
||||
print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}")
|
||||
@@ -2381,7 +2375,7 @@ def test_create_agent_with_tools(client: LettaSDKClient) -> None:
|
||||
name: str = "manage_inventory"
|
||||
args_schema: Type[BaseModel] = InventoryEntryData
|
||||
description: str = "Update inventory catalogue with a new data entry"
|
||||
tags: List[str] = ["inventory", "shop"]
|
||||
tags: ClassVar[List[str]] = ["inventory", "shop"]
|
||||
|
||||
def run(self, data: InventoryEntry, quantity_change: int) -> bool:
|
||||
"""
|
||||
|
||||
@@ -83,7 +83,7 @@ async def custom_anthropic_provider(server: SyncServer, user_id: str):
|
||||
|
||||
@pytest.fixture
|
||||
async def agent(server: SyncServer, user: User):
|
||||
actor = await server.user_manager.get_actor_or_default_async()
|
||||
await server.user_manager.get_actor_or_default_async()
|
||||
agent = await server.create_agent_async(
|
||||
CreateAgent(
|
||||
agent_type="memgpt_v2_agent",
|
||||
@@ -129,7 +129,6 @@ async def test_messages_with_provider_override(server: SyncServer, custom_anthro
|
||||
run_id=run.id,
|
||||
)
|
||||
usage = response.usage
|
||||
messages = response.messages
|
||||
|
||||
get_messages_response = await server.message_manager.list_messages(agent_id=agent.id, actor=actor, after=existing_messages[-1].id)
|
||||
|
||||
@@ -228,7 +227,6 @@ async def test_messages_with_provider_override_legacy_agent(server: SyncServer,
|
||||
run_id=run.id,
|
||||
)
|
||||
usage = response.usage
|
||||
messages = response.messages
|
||||
|
||||
get_messages_response = await server.message_manager.list_messages(agent_id=agent.id, actor=actor, after=existing_messages[-1].id)
|
||||
|
||||
|
||||
@@ -110,7 +110,6 @@ async def test_sync_base_providers_handles_race_condition(default_user, provider
|
||||
|
||||
# Mock a race condition: list returns empty, but create fails with UniqueConstraintViolation
|
||||
original_list = provider_manager.list_providers_async
|
||||
original_create = provider_manager.create_provider_async
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
@@ -2030,14 +2029,14 @@ async def test_get_enabled_providers_async_queries_database(default_user, provid
|
||||
api_key="sk-test-key",
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
|
||||
await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
|
||||
|
||||
byok_provider_create = ProviderCreate(
|
||||
name=f"test-byok-provider-{test_id}",
|
||||
provider_type=ProviderType.anthropic,
|
||||
api_key="sk-test-byok-key",
|
||||
)
|
||||
byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
|
||||
await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
|
||||
|
||||
# Create server instance - importantly, don't set _enabled_providers
|
||||
# This ensures we're testing database queries, not in-memory list
|
||||
@@ -2182,7 +2181,7 @@ async def test_byok_provider_api_key_stored_in_db(default_user, provider_manager
|
||||
provider_type=ProviderType.openai,
|
||||
api_key="sk-byok-should-be-stored",
|
||||
)
|
||||
byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
|
||||
await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
|
||||
|
||||
# Retrieve the provider from database
|
||||
providers = await provider_manager.list_providers_async(name=f"test-byok-with-key-{test_id}", actor=default_user)
|
||||
@@ -2573,7 +2572,7 @@ async def test_byok_provider_last_synced_triggers_sync_when_null(default_user, p
|
||||
|
||||
with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
|
||||
# List BYOK models - should trigger sync because last_synced is null
|
||||
byok_models = await server.list_llm_models_async(
|
||||
await server.list_llm_models_async(
|
||||
actor=default_user,
|
||||
provider_category=[ProviderCategory.byok],
|
||||
)
|
||||
|
||||
@@ -84,7 +84,7 @@ def agent_factory(client: Letta):
|
||||
for agent_state in created_agents:
|
||||
try:
|
||||
client.agents.delete(agent_state.id)
|
||||
except:
|
||||
except Exception:
|
||||
pass # Agent might have already been deleted
|
||||
|
||||
|
||||
|
||||
@@ -89,9 +89,9 @@ def client() -> LettaSDKClient:
|
||||
|
||||
@pytest.fixture
|
||||
def agent_state(disable_pinecone, client: LettaSDKClient):
|
||||
open_file_tool = list(client.tools.list(name="open_files"))[0]
|
||||
search_files_tool = list(client.tools.list(name="semantic_search_files"))[0]
|
||||
grep_tool = list(client.tools.list(name="grep_files"))[0]
|
||||
open_file_tool = next(iter(client.tools.list(name="open_files")))
|
||||
search_files_tool = next(iter(client.tools.list(name="semantic_search_files")))
|
||||
grep_tool = next(iter(client.tools.list(name="grep_files")))
|
||||
|
||||
agent_state = client.agents.create(
|
||||
name="test_sources_agent",
|
||||
@@ -745,13 +745,13 @@ def test_duplicate_file_renaming(disable_pinecone, disable_turbopuffer, client:
|
||||
file_path = "tests/data/test.txt"
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
first_file = client.folders.files.upload(folder_id=source.id, file=f)
|
||||
client.folders.files.upload(folder_id=source.id, file=f)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
second_file = client.folders.files.upload(folder_id=source.id, file=f)
|
||||
client.folders.files.upload(folder_id=source.id, file=f)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
third_file = client.folders.files.upload(folder_id=source.id, file=f)
|
||||
client.folders.files.upload(folder_id=source.id, file=f)
|
||||
|
||||
# Get all uploaded files
|
||||
files = list(client.folders.files.list(folder_id=source.id, limit=10))
|
||||
@@ -821,7 +821,7 @@ def test_duplicate_file_handling_replace(disable_pinecone, disable_turbopuffer,
|
||||
f.write(replacement_content)
|
||||
|
||||
# Upload replacement file with REPLACE duplicate handling
|
||||
replacement_file = upload_file_and_wait(client, source.id, temp_file_path, duplicate_handling="replace")
|
||||
upload_file_and_wait(client, source.id, temp_file_path, duplicate_handling="replace")
|
||||
|
||||
# Verify we still have only 1 file (replacement, not addition)
|
||||
files_after_replace = list(client.folders.files.list(folder_id=source.id, limit=10))
|
||||
|
||||
@@ -77,7 +77,7 @@ def test_get_allowed_tool_names_no_matching_rule_error():
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule])
|
||||
|
||||
solver.register_tool_call(UNRECOGNIZED_TOOL)
|
||||
with pytest.raises(ValueError, match="No valid tools found based on tool rules."):
|
||||
with pytest.raises(ValueError, match=r"No valid tools found based on tool rules."):
|
||||
solver.get_allowed_tool_names(set(), error_on_empty=True)
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ def test_conditional_tool_rule():
|
||||
|
||||
|
||||
def test_invalid_conditional_tool_rule():
|
||||
with pytest.raises(ValueError, match="Conditional tool rule must have at least one child tool."):
|
||||
with pytest.raises(ValueError, match=r"Conditional tool rule must have at least one child tool."):
|
||||
ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})
|
||||
|
||||
|
||||
@@ -402,7 +402,7 @@ def test_cross_type_hash_distinguishes_types(a, b):
|
||||
)
|
||||
def test_equality_with_non_rule_objects(rule):
|
||||
assert rule != object()
|
||||
assert rule != None # noqa: E711
|
||||
assert rule != None
|
||||
|
||||
|
||||
def test_conditional_tool_rule_mapping_order_and_hash():
|
||||
|
||||
@@ -115,7 +115,7 @@ def test_derive_openai_json_schema():
|
||||
# Collect results and check for failures
|
||||
for schema_name, result in results:
|
||||
try:
|
||||
schema_name_result, success = result.get(timeout=60) # Wait for the result with timeout
|
||||
_schema_name_result, success = result.get(timeout=60) # Wait for the result with timeout
|
||||
assert success, f"Test for {schema_name} failed"
|
||||
print(f"Test for {schema_name} passed")
|
||||
except Exception as e:
|
||||
|
||||
@@ -492,7 +492,7 @@ def test_line_chunker_out_of_range_start():
|
||||
chunker = LineChunker()
|
||||
|
||||
# Test with start beyond file length - should raise ValueError
|
||||
with pytest.raises(ValueError, match="File test.py has only 3 lines, but requested offset 6 is out of range"):
|
||||
with pytest.raises(ValueError, match=r"File test.py has only 3 lines, but requested offset 6 is out of range"):
|
||||
chunker.chunk_text(file, start=5, end=6, validate_range=True)
|
||||
|
||||
|
||||
@@ -530,7 +530,7 @@ def test_line_chunker_edge_case_single_line():
|
||||
assert "1: only line" in result[1]
|
||||
|
||||
# Test out of range for single line file - should raise error
|
||||
with pytest.raises(ValueError, match="File single.py has only 1 lines, but requested offset 2 is out of range"):
|
||||
with pytest.raises(ValueError, match=r"File single.py has only 1 lines, but requested offset 2 is out of range"):
|
||||
chunker.chunk_text(file, start=1, end=2, validate_range=True)
|
||||
|
||||
|
||||
@@ -540,7 +540,7 @@ def test_line_chunker_validation_disabled_allows_out_of_range():
|
||||
chunker = LineChunker()
|
||||
|
||||
# Test 1: Out of bounds start should always raise error, even with validation disabled
|
||||
with pytest.raises(ValueError, match="File test.py has only 3 lines, but requested offset 6 is out of range"):
|
||||
with pytest.raises(ValueError, match=r"File test.py has only 3 lines, but requested offset 6 is out of range"):
|
||||
chunker.chunk_text(file, start=5, end=10, validate_range=False)
|
||||
|
||||
# Test 2: With validation disabled, start >= end should be allowed (but gives empty result)
|
||||
@@ -561,7 +561,7 @@ def test_line_chunker_only_start_parameter():
|
||||
assert "3: line3" in result[2]
|
||||
|
||||
# Test start at end of file - should raise error
|
||||
with pytest.raises(ValueError, match="File test.py has only 3 lines, but requested offset 4 is out of range"):
|
||||
with pytest.raises(ValueError, match=r"File test.py has only 3 lines, but requested offset 4 is out of range"):
|
||||
chunker.chunk_text(file, start=3, validate_range=True)
|
||||
|
||||
|
||||
@@ -653,10 +653,10 @@ def test_validate_function_response_strict_mode_none():
|
||||
|
||||
def test_validate_function_response_strict_mode_violation():
|
||||
"""Test strict mode raises ValueError for non-string/None types"""
|
||||
with pytest.raises(ValueError, match="Strict mode violation. Function returned type: int"):
|
||||
with pytest.raises(ValueError, match=r"Strict mode violation. Function returned type: int"):
|
||||
validate_function_response(42, return_char_limit=100, strict=True)
|
||||
|
||||
with pytest.raises(ValueError, match="Strict mode violation. Function returned type: dict"):
|
||||
with pytest.raises(ValueError, match=r"Strict mode violation. Function returned type: dict"):
|
||||
validate_function_response({"key": "value"}, return_char_limit=100, strict=True)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user