diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 215c06e3..f7e6cef2 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -568,7 +568,6 @@ class LettaAgentV3(LettaAgentV2): for message in messages_to_persist: if message.run_id is None: message.run_id = run_id - print("MESSSAGE RUN ID", message.run_id, run_id) persisted_messages = await self.message_manager.create_many_messages_async( messages_to_persist, @@ -743,7 +742,6 @@ class LettaAgentV3(LettaAgentV2): for message in messages_to_persist: if message.run_id is None: message.run_id = run_id - print("MESSSAGE RUN ID", message.run_id, run_id) persisted_messages = await self.message_manager.create_many_messages_async( messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id diff --git a/letta/server/rest_api/redis_stream_manager.py b/letta/server/rest_api/redis_stream_manager.py index 6f97085c..3d98205f 100644 --- a/letta/server/rest_api/redis_stream_manager.py +++ b/letta/server/rest_api/redis_stream_manager.py @@ -9,6 +9,8 @@ from typing import AsyncIterator, Dict, List, Optional from letta.data_sources.redis_client import AsyncRedisClient from letta.log import get_logger from letta.schemas.enums import RunStatus +from letta.schemas.letta_stop_reason import StopReasonType +from letta.schemas.run import RunUpdate from letta.schemas.user import User from letta.services.run_manager import RunManager from letta.utils import safe_create_task @@ -236,8 +238,11 @@ async def create_background_stream_processor( # error_chunk = {"error": {"message": str(e)}} # Mark run_id terminal state if run_manager and actor: - await run_manager.safe_update_run_status_async( - run_id=run_id, new_status=RunStatus.failed, actor=actor, metadata={"error": str(e)} + await run_manager.update_run_by_id_async( + run_id=run_id, + update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value), + actor=actor, + metadata={"error": str(e)}, ) error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"} @@ -245,6 +250,12 @@ async def create_background_stream_processor( finally: if should_stop_writer: await writer.stop() + if run_manager and actor: + await run_manager.update_run_by_id_async( + run_id=run_id, + update=RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn.value), + actor=actor, + ) async def redis_sse_stream_generator( diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 1490f775..26603ec2 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1333,6 +1333,16 @@ async def send_message_streaming( async for chunk in stream: yield chunk + if run: + runs_manager = RunManager() + from letta.schemas.enums import RunStatus + + await runs_manager.update_run_by_id_async( + run_id=run.id, + update=RunUpdate(status=RunStatus.completed, stop_reason=agent_loop.stop_reason.stop_reason.value), + actor=actor, + ) + except LLMTimeoutError as e: error_data = { "error": {"type": "llm_timeout", "message": "The LLM request timed out. Please try again.", "detail": str(e)} diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py index 4782618e..96590c62 100644 --- a/tests/integration_test_send_message_v2.py +++ b/tests/integration_test_send_message_v2.py @@ -33,7 +33,7 @@ from letta_client.types import ( from letta.log import get_logger from letta.schemas.agent import AgentState -from letta.schemas.enums import AgentType +from letta.schemas.enums import AgentType, JobStatus from letta.schemas.letta_ping import LettaPing from letta.schemas.llm_config import LLMConfig @@ -293,6 +293,7 @@ async def test_greeting( messages=USER_MESSAGE_FORCE_REPLY, ) messages = response.messages + run_id = messages[0].run_id elif send_type == "async": run = await client.agents.messages.create_async( agent_id=agent_state.id, @@ -301,6 +302,7 @@ async def test_greeting( run = await wait_for_run_completion(client, run.id) messages = await client.runs.messages.list(run_id=run.id) messages = [m for m in messages if m.message_type != "user_message"] + run_id = run.id else: response = client.agents.messages.create_stream( agent_id=agent_state.id, @@ -324,3 +326,7 @@ async def test_greeting( messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_greeting_response(messages_from_db, from_db=True, llm_config=llm_config) + + assert run_id is not None + run = await client.runs.retrieve(run_id=run_id) + assert run.status == JobStatus.completed