feat: make create_async route consistent with other message routes (#2877)

This commit is contained in:
cthomas
2025-06-19 13:51:51 -07:00
committed by GitHub
parent 56493de971
commit 99e112e486
17 changed files with 209 additions and 36 deletions

View File

@@ -1000,11 +1000,12 @@ class Agent(BaseAgent):
)
if job_id:
for message in all_new_messages:
self.job_manager.add_message_to_job(
job_id=job_id,
message_id=message.id,
actor=self.user,
)
if message.role != "user":
self.job_manager.add_message_to_job(
job_id=job_id,
message_id=message.id,
actor=self.user,
)
return AgentStepResponse(
messages=all_new_messages,

View File

@@ -50,7 +50,9 @@ class BaseAgent(ABC):
self.logger = get_logger(agent_id)
@abstractmethod
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
async def step(
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: Optional[str] = None
) -> LettaResponse:
"""
Main execution loop for the agent.
"""

View File

@@ -43,6 +43,7 @@ from letta.server.rest_api.utils import create_letta_messages_from_llm_response
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
@@ -66,6 +67,7 @@ class LettaAgent(BaseAgent):
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
job_manager: JobManager,
passage_manager: PassageManager,
actor: User,
step_manager: StepManager = NoopStepManager(),
@@ -81,6 +83,7 @@ class LettaAgent(BaseAgent):
# TODO: Make this more general, factorable
# Summarizer settings
self.block_manager = block_manager
self.job_manager = job_manager
self.passage_manager = passage_manager
self.step_manager = step_manager
self.telemetry_manager = telemetry_manager
@@ -120,6 +123,7 @@ class LettaAgent(BaseAgent):
self,
input_messages: List[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
@@ -131,6 +135,7 @@ class LettaAgent(BaseAgent):
agent_state=agent_state,
input_messages=input_messages,
max_steps=max_steps,
run_id=run_id,
request_start_timestamp_ns=request_start_timestamp_ns,
)
return _create_letta_response(
@@ -193,7 +198,6 @@ class LettaAgent(BaseAgent):
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
@@ -302,6 +306,7 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
input_messages: List[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
request_start_timestamp_ns: Optional[int] = None,
) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]:
"""
@@ -345,11 +350,11 @@ class LettaAgent(BaseAgent):
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
usage.run_ids = [run_id] if run_id else None
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
@@ -385,6 +390,7 @@ class LettaAgent(BaseAgent):
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
run_id=run_id,
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
@@ -916,6 +922,7 @@ class LettaAgent(BaseAgent):
initial_messages: Optional[List[Message]] = None,
agent_step_span: Optional["Span"] = None,
is_final_step: Optional[bool] = None,
run_id: Optional[str] = None,
) -> Tuple[List[Message], bool, Optional[LettaStopReason]]:
"""
Now that streaming is done, handle the final AI response.
@@ -1027,7 +1034,7 @@ class LettaAgent(BaseAgent):
# 5a. Persist Steps to DB
# Following agent loop to persist this before messages
# TODO (cliandy): determine what should match old loop w/provider_id, job_id
# TODO (cliandy): determine what should match old loop w/provider_id
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
logged_step = await self.step_manager.log_step_async(
actor=self.actor,
@@ -1039,7 +1046,7 @@ class LettaAgent(BaseAgent):
context_window_limit=agent_state.llm_config.context_window,
usage=usage,
provider_id=None,
job_id=None,
job_id=run_id,
step_id=step_id,
)
@@ -1065,6 +1072,13 @@ class LettaAgent(BaseAgent):
)
self.last_function_response = function_response
if run_id:
await self.job_manager.add_messages_to_job_async(
job_id=run_id,
message_ids=[message.id for message in persisted_messages if message.role != "user"],
actor=self.actor,
)
return persisted_messages, continue_stepping, stop_reason
@trace_method
@@ -1102,6 +1116,7 @@ class LettaAgent(BaseAgent):
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager,
sandbox_env_vars=sandbox_env_vars,
actor=self.actor,

View File

@@ -63,6 +63,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
self,
input_messages: List[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
@@ -83,6 +84,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
@@ -92,6 +94,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
response = await foreground_agent.step(
input_messages=new_messages,
max_steps=max_steps,
run_id=run_id,
use_assistant_message=use_assistant_message,
include_return_message_types=include_return_message_types,
)
@@ -170,6 +173,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
@@ -283,6 +287,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
@@ -296,6 +301,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
result = await sleeptime_agent.step(
input_messages=sleeptime_agent_messages,
use_assistant_message=use_assistant_message,
run_id=run_id,
)
# Update job status

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.orm.enums import JobType
from letta.schemas.enums import JobStatus
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import MessageType
class JobBase(OrmMetadataBase):
@@ -94,3 +95,6 @@ class LettaRequestConfig(BaseModel):
default=DEFAULT_MESSAGE_TOOL_KWARG,
description="The name of the message argument in the designated message tool.",
)
include_return_message_types: Optional[List[MessageType]] = Field(
default=None, description="Only return specified message types in the response. If `None` (default) returns all messages."
)

View File

@@ -39,6 +39,10 @@ class LettaStreamingRequest(LettaRequest):
)
class LettaAsyncRequest(LettaRequest):
callback_url: Optional[str] = Field(None, description="Optional callback URL to POST to when the job completes")
class LettaBatchRequest(LettaRequest):
agent_id: str = Field(..., description="The ID of the agent to send this batch request for")

View File

@@ -25,7 +25,7 @@ from letta.schemas.block import Block, BlockUpdate
from letta.schemas.group import Group
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.message import MessageCreate
@@ -707,6 +707,7 @@ async def send_message(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
@@ -793,6 +794,7 @@ async def send_message_streaming(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
@@ -884,6 +886,7 @@ async def process_message_background(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
@@ -893,6 +896,7 @@ async def process_message_background(
result = await agent_loop.step(
messages,
max_steps=max_steps,
run_id=job_id,
use_assistant_message=use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=include_return_message_types,
@@ -904,6 +908,7 @@ async def process_message_background(
input_messages=messages,
stream_steps=False,
stream_tokens=False,
metadata={"job_id": job_id},
# Support for AssistantMessage
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
@@ -936,9 +941,8 @@ async def process_message_background(
async def send_message_async(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
request: LettaAsyncRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"),
callback_url: Optional[str] = Query(None, description="Optional callback URL to POST to when the job completes"),
):
"""
Asynchronously process a user message and return a run object.
@@ -951,7 +955,7 @@ async def send_message_async(
run = Run(
user_id=actor.id,
status=JobStatus.created,
callback_url=callback_url,
callback_url=request.callback_url,
metadata={
"job_type": "send_message_async",
"agent_id": agent_id,
@@ -960,6 +964,7 @@ async def send_message_async(
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
),
)
run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor)
@@ -1036,6 +1041,7 @@ async def summarize_agent_conversation(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,

View File

@@ -92,7 +92,7 @@ async def list_run_messages(
after: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
order: str = Query(
"desc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order."
"asc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order."
),
role: Optional[MessageRole] = Query(None, description="Filter by role"),
):

View File

@@ -1355,6 +1355,7 @@ class SyncServer(Server):
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager,
actor=actor,
step_manager=self.step_manager,
@@ -1996,6 +1997,7 @@ class SyncServer(Server):
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager,
actor=actor,
sandbox_env_vars=tool_env_vars,

View File

@@ -342,6 +342,33 @@ class JobManager:
session.add(job_message)
session.commit()
@enforce_types
@trace_method
async def add_messages_to_job_async(self, job_id: str, message_ids: List[str], actor: PydanticUser) -> None:
"""
Associate multiple messages with a job by creating JobMessage records.
Each message can only be associated with one job.
Args:
job_id: The ID of the job
message_ids: The IDs of the messages to associate
actor: The user making the request
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
if not message_ids:
return
async with db_registry.async_session() as session:
# First verify job exists and user has access
await self._verify_job_access_async(session, job_id, actor, access=["write"])
# Create new JobMessage associations
job_messages = [JobMessage(job_id=job_id, message_id=message_id) for message_id in message_ids]
session.add_all(job_messages)
await session.commit()
@enforce_types
@trace_method
def get_job_usage(self, job_id: str, actor: PydanticUser) -> LettaUsageStatistics:
@@ -463,14 +490,19 @@ class JobManager:
)
request_config = self._get_run_request_config(run_id)
print("request_config", request_config)
messages = PydanticMessage.to_letta_messages_from_list(
messages=messages,
use_assistant_message=request_config["use_assistant_message"],
assistant_message_tool_name=request_config["assistant_message_tool_name"],
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
reverse=not ascending,
)
if request_config["include_return_message_types"]:
messages = [msg for msg in messages if msg.message_type in request_config["include_return_message_types"]]
return messages
@enforce_types

View File

@@ -101,6 +101,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager,
actor=self.actor,
)

View File

@@ -15,6 +15,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
@@ -49,6 +50,7 @@ class ToolExecutorFactory:
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
job_manager: JobManager,
passage_manager: PassageManager,
actor: User,
) -> ToolExecutor:
@@ -58,6 +60,7 @@ class ToolExecutorFactory:
message_manager=message_manager,
agent_manager=agent_manager,
block_manager=block_manager,
job_manager=job_manager,
passage_manager=passage_manager,
actor=actor,
)
@@ -71,6 +74,7 @@ class ToolExecutionManager:
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
job_manager: JobManager,
passage_manager: PassageManager,
actor: User,
agent_state: Optional[AgentState] = None,
@@ -80,6 +84,7 @@ class ToolExecutionManager:
self.message_manager = message_manager
self.agent_manager = agent_manager
self.block_manager = block_manager
self.job_manager = job_manager
self.passage_manager = passage_manager
self.agent_state = agent_state
self.logger = get_logger(__name__)
@@ -101,6 +106,7 @@ class ToolExecutionManager:
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager,
actor=self.actor,
)

View File

@@ -8,6 +8,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
@@ -20,12 +21,14 @@ class ToolExecutor(ABC):
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
job_manager: JobManager,
passage_manager: PassageManager,
actor: User,
):
self.message_manager = message_manager
self.agent_manager = agent_manager
self.block_manager = block_manager
self.job_manager = job_manager
self.passage_manager = passage_manager
self.actor = actor

View File

@@ -71,6 +71,7 @@ async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
agent_state=agent_state,
actor=default_user,

View File

@@ -19,6 +19,7 @@ from letta_client.types import (
Base64Image,
HiddenReasoningMessage,
ImageContent,
LettaMessageUnion,
LettaStopReason,
LettaUsageStatistics,
ReasoningMessage,
@@ -351,20 +352,24 @@ def accumulate_chunks(chunks: List[Any]) -> List[Any]:
return [m for m in messages if m is not None]
def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
"""
Asserts that a list of message dictionaries contains the expected types and statuses.
def cast_message_dict_to_messages(messages: List[Dict[str, Any]]) -> List[LettaMessageUnion]:
def cast_message(message: Dict[str, Any]) -> LettaMessageUnion:
if message["message_type"] == "reasoning_message":
return ReasoningMessage(**message)
elif message["message_type"] == "assistant_message":
return AssistantMessage(**message)
elif message["message_type"] == "tool_call_message":
return ToolCallMessage(**message)
elif message["message_type"] == "tool_return_message":
return ToolReturnMessage(**message)
elif message["message_type"] == "user_message":
return UserMessage(**message)
elif message["message_type"] == "hidden_reasoning_message":
return HiddenReasoningMessage(**message)
else:
raise ValueError(f"Unknown message type: {message['message_type']}")
Expected order:
1. reasoning_message
2. tool_call_message
3. tool_return_message (with status 'success')
4. reasoning_message
5. assistant_message
"""
assert isinstance(messages, list)
assert messages[0]["message_type"] == "reasoning_message"
assert messages[1]["message_type"] == "assistant_message"
return [cast_message(message) for message in messages]
# ------------------------------
@@ -870,6 +875,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i
if run.status == "completed":
return run
if run.status == "failed":
print(run)
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
if time.time() - start > timeout:
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
@@ -891,6 +897,7 @@ def test_async_greeting_with_assistant_message(
Tests sending a message as an asynchronous job using the synchronous client.
Waits for job completion and asserts that the result messages are as expected.
"""
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
run = client.agents.messages.create_async(
@@ -902,8 +909,86 @@ def test_async_greeting_with_assistant_message(
result = run.metadata.get("result")
assert result is not None, "Run metadata missing 'result' key"
messages = result["messages"]
assert_tool_response_dict_messages(messages)
messages = cast_message_dict_to_messages(result["messages"])
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
messages = client.runs.messages.list(run_id=run.id)
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_async_greeting_without_assistant_message(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message as an asynchronous job using the synchronous client.
Waits for job completion and asserts that the result messages are as expected.
"""
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
run = client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
use_assistant_message=False,
)
run = wait_for_run_completion(client, run.id)
result = run.metadata.get("result")
assert result is not None, "Run metadata missing 'result' key"
messages = cast_message_dict_to_messages(result["messages"])
assert_greeting_without_assistant_message_response(messages, llm_config=llm_config)
messages = client.runs.messages.list(run_id=run.id)
assert_greeting_without_assistant_message_response(messages, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_async_tool_call(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message as an asynchronous job using the synchronous client.
Waits for job completion and asserts that the result messages are as expected.
"""
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
run = client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
)
run = wait_for_run_completion(client, run.id)
result = run.metadata.get("result")
assert result is not None, "Run metadata missing 'result' key"
messages = cast_message_dict_to_messages(result["messages"])
assert_tool_call_response(messages, llm_config=llm_config)
messages = client.runs.messages.list(run_id=run.id)
assert_tool_call_response(messages, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
class CallbackServer:
@@ -1021,8 +1106,9 @@ def test_async_greeting_with_callback_url(
# Validate job completed successfully
result = run.metadata.get("result")
assert result is not None, "Run metadata missing 'result' key"
messages = result["messages"]
assert_tool_response_dict_messages(messages)
messages = cast_message_dict_to_messages(result["messages"])
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
# Validate callback was received
assert server.wait_for_callback(timeout=15), "Callback was not received within timeout"
@@ -1084,8 +1170,9 @@ def test_async_callback_failure_scenarios(
# Validate job completed successfully
result = run.metadata.get("result")
assert result is not None, "Run metadata missing 'result' key"
messages = result["messages"]
assert_tool_response_dict_messages(messages)
messages = cast_message_dict_to_messages(result["messages"])
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
# Job should be marked as completed even if callback failed
assert run.status == "completed", f"Expected status 'completed', got {run.status}"

View File

@@ -110,6 +110,7 @@ async def test_provider_trace_experimental_step(message, agent_state, default_us
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
job_manager=JobManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=TelemetryManager(),
@@ -134,6 +135,7 @@ async def test_provider_trace_experimental_step_stream(message, agent_state, def
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
job_manager=JobManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=TelemetryManager(),
@@ -189,6 +191,7 @@ async def test_noop_provider_trace(message, agent_state, default_user, event_loo
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
job_manager=JobManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=NoopTelemetryManager(),

View File

@@ -637,7 +637,7 @@ def test_many_blocks(client: LettaSDKClient):
# cases: steam, async, token stream, sync
@pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync"])
@pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync", "async"])
def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, message_create: str):
"""Test that the include_return_message_types parameter works"""