feat: check run completion in send message tests (#5030)
This commit is contained in:
@@ -568,7 +568,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
for message in messages_to_persist:
|
||||
if message.run_id is None:
|
||||
message.run_id = run_id
|
||||
print("MESSSAGE RUN ID", message.run_id, run_id)
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist,
|
||||
@@ -743,7 +742,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
for message in messages_to_persist:
|
||||
if message.run_id is None:
|
||||
message.run_id = run_id
|
||||
print("MESSSAGE RUN ID", message.run_id, run_id)
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
|
||||
|
||||
@@ -9,6 +9,8 @@ from typing import AsyncIterator, Dict, List, Optional
|
||||
from letta.data_sources.redis_client import AsyncRedisClient
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import RunStatus
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.run import RunUpdate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.utils import safe_create_task
|
||||
@@ -236,8 +238,11 @@ async def create_background_stream_processor(
|
||||
# error_chunk = {"error": {"message": str(e)}}
|
||||
# Mark run_id terminal state
|
||||
if run_manager and actor:
|
||||
await run_manager.safe_update_run_status_async(
|
||||
run_id=run_id, new_status=RunStatus.failed, actor=actor, metadata={"error": str(e)}
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value),
|
||||
actor=actor,
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
|
||||
error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"}
|
||||
@@ -245,6 +250,12 @@ async def create_background_stream_processor(
|
||||
finally:
|
||||
if should_stop_writer:
|
||||
await writer.stop()
|
||||
if run_manager and actor:
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn.value),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
async def redis_sse_stream_generator(
|
||||
|
||||
@@ -1333,6 +1333,16 @@ async def send_message_streaming(
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
if run:
|
||||
runs_manager = RunManager()
|
||||
from letta.schemas.enums import RunStatus
|
||||
|
||||
await runs_manager.update_run_by_id_async(
|
||||
run_id=run.id,
|
||||
update=RunUpdate(status=RunStatus.completed, stop_reason=agent_loop.stop_reason.stop_reason.value),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
except LLMTimeoutError as e:
|
||||
error_data = {
|
||||
"error": {"type": "llm_timeout", "message": "The LLM request timed out. Please try again.", "detail": str(e)}
|
||||
|
||||
@@ -33,7 +33,7 @@ from letta_client.types import (
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import AgentType
|
||||
from letta.schemas.enums import AgentType, JobStatus
|
||||
from letta.schemas.letta_ping import LettaPing
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
@@ -293,6 +293,7 @@ async def test_greeting(
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
messages = response.messages
|
||||
run_id = messages[0].run_id
|
||||
elif send_type == "async":
|
||||
run = await client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
@@ -301,6 +302,7 @@ async def test_greeting(
|
||||
run = await wait_for_run_completion(client, run.id)
|
||||
messages = await client.runs.messages.list(run_id=run.id)
|
||||
messages = [m for m in messages if m.message_type != "user_message"]
|
||||
run_id = run.id
|
||||
else:
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
@@ -324,3 +326,7 @@ async def test_greeting(
|
||||
|
||||
messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
assert run_id is not None
|
||||
run = await client.runs.retrieve(run_id=run_id)
|
||||
assert run.status == JobStatus.completed
|
||||
|
||||
Reference in New Issue
Block a user