feat: check run completion in send message tests (#5030)

This commit is contained in:
cthomas
2025-09-30 15:00:45 -07:00
committed by Caren Thomas
parent 2916095e86
commit cd900a6f4d
4 changed files with 30 additions and 5 deletions

View File

@@ -568,7 +568,6 @@ class LettaAgentV3(LettaAgentV2):
for message in messages_to_persist:
if message.run_id is None:
message.run_id = run_id
print("MESSSAGE RUN ID", message.run_id, run_id)
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist,
@@ -743,7 +742,6 @@ class LettaAgentV3(LettaAgentV2):
for message in messages_to_persist:
if message.run_id is None:
message.run_id = run_id
print("MESSSAGE RUN ID", message.run_id, run_id)
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id

View File

@@ -9,6 +9,8 @@ from typing import AsyncIterator, Dict, List, Optional
from letta.data_sources.redis_client import AsyncRedisClient
from letta.log import get_logger
from letta.schemas.enums import RunStatus
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.run import RunUpdate
from letta.schemas.user import User
from letta.services.run_manager import RunManager
from letta.utils import safe_create_task
@@ -236,8 +238,11 @@ async def create_background_stream_processor(
# error_chunk = {"error": {"message": str(e)}}
# Mark run_id terminal state
if run_manager and actor:
await run_manager.safe_update_run_status_async(
run_id=run_id, new_status=RunStatus.failed, actor=actor, metadata={"error": str(e)}
await run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value),
actor=actor,
metadata={"error": str(e)},
)
error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"}
@@ -245,6 +250,12 @@ async def create_background_stream_processor(
finally:
if should_stop_writer:
await writer.stop()
if run_manager and actor:
await run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn.value),
actor=actor,
)
async def redis_sse_stream_generator(

View File

@@ -1333,6 +1333,16 @@ async def send_message_streaming(
async for chunk in stream:
yield chunk
if run:
runs_manager = RunManager()
from letta.schemas.enums import RunStatus
await runs_manager.update_run_by_id_async(
run_id=run.id,
update=RunUpdate(status=RunStatus.completed, stop_reason=agent_loop.stop_reason.stop_reason.value),
actor=actor,
)
except LLMTimeoutError as e:
error_data = {
"error": {"type": "llm_timeout", "message": "The LLM request timed out. Please try again.", "detail": str(e)}

View File

@@ -33,7 +33,7 @@ from letta_client.types import (
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType
from letta.schemas.enums import AgentType, JobStatus
from letta.schemas.letta_ping import LettaPing
from letta.schemas.llm_config import LLMConfig
@@ -293,6 +293,7 @@ async def test_greeting(
messages=USER_MESSAGE_FORCE_REPLY,
)
messages = response.messages
run_id = messages[0].run_id
elif send_type == "async":
run = await client.agents.messages.create_async(
agent_id=agent_state.id,
@@ -301,6 +302,7 @@ async def test_greeting(
run = await wait_for_run_completion(client, run.id)
messages = await client.runs.messages.list(run_id=run.id)
messages = [m for m in messages if m.message_type != "user_message"]
run_id = run.id
else:
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
@@ -324,3 +326,7 @@ async def test_greeting(
messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_response(messages_from_db, from_db=True, llm_config=llm_config)
assert run_id is not None
run = await client.runs.retrieve(run_id=run_id)
assert run.status == JobStatus.completed