diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index 240b286b..1490f775 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -33,7 +33,7 @@ from letta.otel.metric_registry import MetricRegistry
 from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
 from letta.schemas.agent_file import AgentFileSchema
 from letta.schemas.block import Block, BlockUpdate
-from letta.schemas.enums import RunStatus
+from letta.schemas.enums import AgentType, RunStatus
 from letta.schemas.file import AgentFileAttachment, PaginatedAgentFiles
 from letta.schemas.group import Group
 from letta.schemas.job import LettaRequestConfig
@@ -1659,7 +1659,7 @@ async def send_message_async(
         agent_state = await server.agent_manager.get_agent_by_id_async(
             agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
         )
-        if agent_state.multi_agent_group is None:
+        if agent_state.multi_agent_group is None and agent_state.agent_type != AgentType.letta_v1_agent:
             temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor)
             await temporal_agent.step(
                 input_messages=request.messages,
diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py
index 37ed101f..e73da147 100644
--- a/tests/integration_test_send_message_v2.py
+++ b/tests/integration_test_send_message_v2.py
@@ -178,6 +178,38 @@ async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = Fa
     return [m for m in messages if m is not None]
 
 
+async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
+    """Poll a run until it completes, fails, or the timeout elapses.
+
+    Args:
+        client: Async client whose ``runs.retrieve`` is awaited on every poll.
+        run_id: ID of the run to poll.
+        timeout: Maximum number of seconds to keep polling.
+        interval: Number of seconds to wait between polls.
+
+    Returns:
+        The retrieved run once its status is "completed".
+
+    Raises:
+        RuntimeError: If the run reports status "failed".
+        TimeoutError: If the run has not completed after ``timeout`` seconds.
+    """
+    import asyncio  # function-local so the helper does not depend on the module's import block
+
+    start = time.monotonic()  # monotonic clock: elapsed time is unaffected by wall-clock adjustments
+    while True:
+        run = await client.runs.retrieve(run_id)
+        if run.status == "completed":
+            return run
+        if run.status == "failed":
+            print(run)
+            raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
+        if time.monotonic() - start > timeout:
+            raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
+        # Yield to the event loop between polls; a blocking time.sleep() would stall every other coroutine.
+        await asyncio.sleep(interval)
+
+
 # ------------------------------
 # Fixtures
 # ------------------------------
@@ -261,8 +293,8 @@ async def agent_state(client: AsyncLetta) -> AgentState:
     TESTED_LLM_CONFIGS,
     ids=[c.model for c in TESTED_LLM_CONFIGS],
 )
-@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens"])
-@pytest.mark.asyncio(scope="function")
+@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "async"])
+@pytest.mark.asyncio(loop_scope="function")
 async def test_greeting(
     disable_e2b_api_key: Any,
     client: AsyncLetta,
@@ -279,6 +311,14 @@ async def test_greeting(
             messages=USER_MESSAGE_FORCE_REPLY,
         )
         messages = response.messages
+    elif send_type == "async":
+        run = await client.agents.messages.create_async(
+            agent_id=agent_state.id,
+            messages=USER_MESSAGE_FORCE_REPLY,
+        )
+        run = await wait_for_run_completion(client, run.id)
+        messages = await client.runs.messages.list(run_id=run.id)
+        messages = [m for m in messages if m.message_type != "user_message"]
     else:
         response = client.agents.messages.create_stream(
             agent_id=agent_state.id,