feat: add messages.create_async test for new agent loop (#5024)

feat: add async test for new agent loop
This commit is contained in:
cthomas
2025-09-30 13:14:04 -07:00
committed by Caren Thomas
parent 2af3130be1
commit e248ac27e2
2 changed files with 26 additions and 4 deletions

View File

@@ -33,7 +33,7 @@ from letta.otel.metric_registry import MetricRegistry
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.agent_file import AgentFileSchema
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.enums import RunStatus
from letta.schemas.enums import AgentType, RunStatus
from letta.schemas.file import AgentFileAttachment, PaginatedAgentFiles
from letta.schemas.group import Group
from letta.schemas.job import LettaRequestConfig
@@ -1659,7 +1659,7 @@ async def send_message_async(
agent_state = await server.agent_manager.get_agent_by_id_async(
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
)
if agent_state.multi_agent_group is None:
if agent_state.multi_agent_group is None and agent_state.agent_type != AgentType.letta_v1_agent:
temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor)
await temporal_agent.step(
input_messages=request.messages,

View File

@@ -178,6 +178,20 @@ async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = Fa
return [m for m in messages if m is not None]
async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
start = time.time()
while True:
run = await client.runs.retrieve(run_id)
if run.status == "completed":
return run
if run.status == "failed":
print(run)
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
if time.time() - start > timeout:
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
time.sleep(interval)
# ------------------------------
# Fixtures
# ------------------------------
@@ -261,8 +275,8 @@ async def agent_state(client: AsyncLetta) -> AgentState:
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens"])
@pytest.mark.asyncio(scope="function")
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_greeting(
disable_e2b_api_key: Any,
client: AsyncLetta,
@@ -279,6 +293,14 @@ async def test_greeting(
messages=USER_MESSAGE_FORCE_REPLY,
)
messages = response.messages
elif send_type == "async":
run = await client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
run = await wait_for_run_completion(client, run.id)
messages = await client.runs.messages.list(run_id=run.id)
messages = [m for m in messages if m.message_type != "user_message"]
else:
response = client.agents.messages.create_stream(
agent_id=agent_state.id,