feat: add messages.create_async test for new agent loop (#5024)
feat: add async test for new agent loop
This commit is contained in:
@@ -33,7 +33,7 @@ from letta.otel.metric_registry import MetricRegistry
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent_file import AgentFileSchema
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.enums import RunStatus
|
||||
from letta.schemas.enums import AgentType, RunStatus
|
||||
from letta.schemas.file import AgentFileAttachment, PaginatedAgentFiles
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.job import LettaRequestConfig
|
||||
@@ -1659,7 +1659,7 @@ async def send_message_async(
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(
|
||||
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
|
||||
)
|
||||
if agent_state.multi_agent_group is None:
|
||||
if agent_state.multi_agent_group is None and agent_state.agent_type != AgentType.letta_v1_agent:
|
||||
temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor)
|
||||
await temporal_agent.step(
|
||||
input_messages=request.messages,
|
||||
|
||||
@@ -178,6 +178,20 @@ async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = Fa
|
||||
return [m for m in messages if m is not None]
|
||||
|
||||
|
||||
async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
|
||||
start = time.time()
|
||||
while True:
|
||||
run = await client.runs.retrieve(run_id)
|
||||
if run.status == "completed":
|
||||
return run
|
||||
if run.status == "failed":
|
||||
print(run)
|
||||
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------
|
||||
@@ -261,8 +275,8 @@ async def agent_state(client: AsyncLetta) -> AgentState:
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens"])
|
||||
@pytest.mark.asyncio(scope="function")
|
||||
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "async"])
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_greeting(
|
||||
disable_e2b_api_key: Any,
|
||||
client: AsyncLetta,
|
||||
@@ -279,6 +293,14 @@ async def test_greeting(
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
messages = response.messages
|
||||
elif send_type == "async":
|
||||
run = await client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
run = await wait_for_run_completion(client, run.id)
|
||||
messages = await client.runs.messages.list(run_id=run.id)
|
||||
messages = [m for m in messages if m.message_type != "user_message"]
|
||||
else:
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
|
||||
Reference in New Issue
Block a user