From 99e112e486661484e0b4a6fe3ed3ba10d2f45fbb Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 19 Jun 2025 13:51:51 -0700 Subject: [PATCH] feat: make create_async route consistent with other message routes (#2877) --- letta/agent.py | 11 +- letta/agents/base_agent.py | 4 +- letta/agents/letta_agent.py | 23 +++- letta/groups/sleeptime_multi_agent_v2.py | 6 + letta/schemas/job.py | 6 +- letta/schemas/letta_request.py | 4 + letta/server/rest_api/routers/v1/agents.py | 14 +- letta/server/rest_api/routers/v1/runs.py | 2 +- letta/server/server.py | 2 + letta/services/job_manager.py | 32 +++++ .../multi_agent_tool_executor.py | 1 + .../tool_executor/tool_execution_manager.py | 6 + .../tool_executor/tool_executor_base.py | 3 + tests/integration_test_composio.py | 1 + tests/integration_test_send_message.py | 125 +++++++++++++++--- tests/test_provider_trace.py | 3 + tests/test_sdk_client.py | 2 +- 17 files changed, 209 insertions(+), 36 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index ef8f69a9..ea009f2f 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1000,11 +1000,12 @@ class Agent(BaseAgent): ) if job_id: for message in all_new_messages: - self.job_manager.add_message_to_job( - job_id=job_id, - message_id=message.id, - actor=self.user, - ) + if message.role != "user": + self.job_manager.add_message_to_job( + job_id=job_id, + message_id=message.id, + actor=self.user, + ) return AgentStepResponse( messages=all_new_messages, diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index cf903c79..bf451329 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -50,7 +50,9 @@ class BaseAgent(ABC): self.logger = get_logger(agent_id) @abstractmethod - async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse: + async def step( + self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: Optional[str] = None + ) -> LettaResponse: """ Main execution 
loop for the agent. """ diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 4fc949c0..d774f110 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -43,6 +43,7 @@ from letta.server.rest_api.utils import create_letta_messages_from_llm_response from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema +from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.step_manager import NoopStepManager, StepManager @@ -66,6 +67,7 @@ class LettaAgent(BaseAgent): message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, + job_manager: JobManager, passage_manager: PassageManager, actor: User, step_manager: StepManager = NoopStepManager(), @@ -81,6 +83,7 @@ class LettaAgent(BaseAgent): # TODO: Make this more general, factorable # Summarizer settings self.block_manager = block_manager + self.job_manager = job_manager self.passage_manager = passage_manager self.step_manager = step_manager self.telemetry_manager = telemetry_manager @@ -120,6 +123,7 @@ class LettaAgent(BaseAgent): self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, + run_id: Optional[str] = None, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, include_return_message_types: Optional[List[MessageType]] = None, @@ -131,6 +135,7 @@ class LettaAgent(BaseAgent): agent_state=agent_state, input_messages=input_messages, max_steps=max_steps, + run_id=run_id, request_start_timestamp_ns=request_start_timestamp_ns, ) return _create_letta_response( @@ -193,7 +198,6 @@ class LettaAgent(BaseAgent): response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) # update 
usage - # TODO: add run_id usage.step_count += 1 usage.completion_tokens += response.usage.completion_tokens usage.prompt_tokens += response.usage.prompt_tokens @@ -302,6 +306,7 @@ class LettaAgent(BaseAgent): agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, + run_id: Optional[str] = None, request_start_timestamp_ns: Optional[int] = None, ) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]: """ @@ -345,11 +350,11 @@ class LettaAgent(BaseAgent): response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) - # TODO: add run_id usage.step_count += 1 usage.completion_tokens += response.usage.completion_tokens usage.prompt_tokens += response.usage.prompt_tokens usage.total_tokens += response.usage.total_tokens + usage.run_ids = [run_id] if run_id else None MetricRegistry().message_output_tokens.record( response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) ) @@ -385,6 +390,7 @@ class LettaAgent(BaseAgent): initial_messages=initial_messages, agent_step_span=agent_step_span, is_final_step=(i == max_steps - 1), + run_id=run_id, ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) @@ -916,6 +922,7 @@ class LettaAgent(BaseAgent): initial_messages: Optional[List[Message]] = None, agent_step_span: Optional["Span"] = None, is_final_step: Optional[bool] = None, + run_id: Optional[str] = None, ) -> Tuple[List[Message], bool, Optional[LettaStopReason]]: """ Now that streaming is done, handle the final AI response. @@ -1027,7 +1034,7 @@ class LettaAgent(BaseAgent): # 5a. 
Persist Steps to DB # Following agent loop to persist this before messages - # TODO (cliandy): determine what should match old loop w/provider_id, job_id + # TODO (cliandy): determine what should match old loop w/provider_id # TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same. logged_step = await self.step_manager.log_step_async( actor=self.actor, @@ -1039,7 +1046,7 @@ class LettaAgent(BaseAgent): context_window_limit=agent_state.llm_config.context_window, usage=usage, provider_id=None, - job_id=None, + job_id=run_id, step_id=step_id, ) @@ -1065,6 +1072,13 @@ class LettaAgent(BaseAgent): ) self.last_function_response = function_response + if run_id: + await self.job_manager.add_messages_to_job_async( + job_id=run_id, + message_ids=[message.id for message in persisted_messages if message.role != "user"], + actor=self.actor, + ) + return persisted_messages, continue_stepping, stop_reason @trace_method @@ -1102,6 +1116,7 @@ class LettaAgent(BaseAgent): message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, + job_manager=self.job_manager, passage_manager=self.passage_manager, sandbox_env_vars=sandbox_env_vars, actor=self.actor, diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index 97c9216b..78d587d9 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -63,6 +63,7 @@ class SleeptimeMultiAgentV2(BaseAgent): self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, + run_id: Optional[str] = None, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, include_return_message_types: Optional[List[MessageType]] = None, @@ -83,6 +84,7 @@ class SleeptimeMultiAgentV2(BaseAgent): message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, + job_manager=self.job_manager, 
passage_manager=self.passage_manager, actor=self.actor, step_manager=self.step_manager, @@ -92,6 +94,7 @@ class SleeptimeMultiAgentV2(BaseAgent): response = await foreground_agent.step( input_messages=new_messages, max_steps=max_steps, + run_id=run_id, use_assistant_message=use_assistant_message, include_return_message_types=include_return_message_types, ) @@ -170,6 +173,7 @@ class SleeptimeMultiAgentV2(BaseAgent): message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, + job_manager=self.job_manager, passage_manager=self.passage_manager, actor=self.actor, step_manager=self.step_manager, @@ -283,6 +287,7 @@ class SleeptimeMultiAgentV2(BaseAgent): message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, + job_manager=self.job_manager, passage_manager=self.passage_manager, actor=self.actor, step_manager=self.step_manager, @@ -296,6 +301,7 @@ class SleeptimeMultiAgentV2(BaseAgent): result = await sleeptime_agent.step( input_messages=sleeptime_agent_messages, use_assistant_message=use_assistant_message, + run_id=run_id, ) # Update job status diff --git a/letta/schemas/job.py b/letta/schemas/job.py index eb6c8537..4f2b09c4 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional +from typing import List, Optional from pydantic import BaseModel, Field @@ -7,6 +7,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.orm.enums import JobType from letta.schemas.enums import JobStatus from letta.schemas.letta_base import OrmMetadataBase +from letta.schemas.letta_message import MessageType class JobBase(OrmMetadataBase): @@ -94,3 +95,6 @@ class LettaRequestConfig(BaseModel): default=DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument in the designated message tool.", ) + include_return_message_types: Optional[List[MessageType]] = Field( + 
default=None, description="Only return specified message types in the response. If `None` (default) returns all messages." + ) diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index 222de433..fdea2918 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -39,6 +39,10 @@ class LettaStreamingRequest(LettaRequest): ) +class LettaAsyncRequest(LettaRequest): + callback_url: Optional[str] = Field(None, description="Optional callback URL to POST to when the job completes") + + class LettaBatchRequest(LettaRequest): agent_id: str = Field(..., description="The ID of the agent to send this batch request for") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index f1fef585..68bbc8ef 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -25,7 +25,7 @@ from letta.schemas.block import Block, BlockUpdate from letta.schemas.group import Group from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType -from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest +from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory from letta.schemas.message import MessageCreate @@ -707,6 +707,7 @@ async def send_message( message_manager=server.message_manager, agent_manager=server.agent_manager, block_manager=server.block_manager, + job_manager=server.job_manager, passage_manager=server.passage_manager, actor=actor, step_manager=server.step_manager, @@ -793,6 +794,7 @@ async def send_message_streaming( message_manager=server.message_manager, agent_manager=server.agent_manager, block_manager=server.block_manager, + 
job_manager=server.job_manager, passage_manager=server.passage_manager, actor=actor, step_manager=server.step_manager, @@ -884,6 +886,7 @@ async def process_message_background( message_manager=server.message_manager, agent_manager=server.agent_manager, block_manager=server.block_manager, + job_manager=server.job_manager, passage_manager=server.passage_manager, actor=actor, step_manager=server.step_manager, @@ -893,6 +896,7 @@ async def process_message_background( result = await agent_loop.step( messages, max_steps=max_steps, + run_id=job_id, use_assistant_message=use_assistant_message, request_start_timestamp_ns=request_start_timestamp_ns, include_return_message_types=include_return_message_types, @@ -904,6 +908,7 @@ async def process_message_background( input_messages=messages, stream_steps=False, stream_tokens=False, + metadata={"job_id": job_id}, # Support for AssistantMessage use_assistant_message=use_assistant_message, assistant_message_tool_name=assistant_message_tool_name, @@ -936,9 +941,8 @@ async def process_message_background( async def send_message_async( agent_id: str, server: SyncServer = Depends(get_letta_server), - request: LettaRequest = Body(...), + request: LettaAsyncRequest = Body(...), actor_id: Optional[str] = Header(None, alias="user_id"), - callback_url: Optional[str] = Query(None, description="Optional callback URL to POST to when the job completes"), ): """ Asynchronously process a user message and return a run object. 
@@ -951,7 +955,7 @@ async def send_message_async( run = Run( user_id=actor.id, status=JobStatus.created, - callback_url=callback_url, + callback_url=request.callback_url, metadata={ "job_type": "send_message_async", "agent_id": agent_id, @@ -960,6 +964,7 @@ async def send_message_async( use_assistant_message=request.use_assistant_message, assistant_message_tool_name=request.assistant_message_tool_name, assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + include_return_message_types=request.include_return_message_types, ), ) run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor) @@ -1036,6 +1041,7 @@ async def summarize_agent_conversation( message_manager=server.message_manager, agent_manager=server.agent_manager, block_manager=server.block_manager, + job_manager=server.job_manager, passage_manager=server.passage_manager, actor=actor, step_manager=server.step_manager, diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index ac0476e4..5974bc28 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -92,7 +92,7 @@ async def list_run_messages( after: Optional[str] = Query(None, description="Cursor for pagination"), limit: Optional[int] = Query(100, description="Maximum number of messages to return"), order: str = Query( - "desc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order." + "asc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order." 
), role: Optional[MessageRole] = Query(None, description="Filter by role"), ): diff --git a/letta/server/server.py b/letta/server/server.py index 4f5e766f..242573bc 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1355,6 +1355,7 @@ class SyncServer(Server): message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, + job_manager=self.job_manager, passage_manager=self.passage_manager, actor=actor, step_manager=self.step_manager, @@ -1996,6 +1997,7 @@ class SyncServer(Server): message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, + job_manager=self.job_manager, passage_manager=self.passage_manager, actor=actor, sandbox_env_vars=tool_env_vars, diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index a5661698..c7201fc5 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -342,6 +342,33 @@ class JobManager: session.add(job_message) session.commit() + @enforce_types + @trace_method + async def add_messages_to_job_async(self, job_id: str, message_ids: List[str], actor: PydanticUser) -> None: + """ + Associate a message with a job by creating a JobMessage record. + Each message can only be associated with one job. 
+ + Args: + job_id: The ID of the job + message_ids: The IDs of the messages to associate + actor: The user making the request + + Raises: + NoResultFound: If the job does not exist or user does not have access + """ + if not message_ids: + return + + async with db_registry.async_session() as session: + # First verify job exists and user has access + await self._verify_job_access_async(session, job_id, actor, access=["write"]) + + # Create new JobMessage associations + job_messages = [JobMessage(job_id=job_id, message_id=message_id) for message_id in message_ids] + session.add_all(job_messages) + await session.commit() + @enforce_types @trace_method def get_job_usage(self, job_id: str, actor: PydanticUser) -> LettaUsageStatistics: @@ -463,14 +490,19 @@ class JobManager: ) request_config = self._get_run_request_config(run_id) + print("request_config", request_config) messages = PydanticMessage.to_letta_messages_from_list( messages=messages, use_assistant_message=request_config["use_assistant_message"], assistant_message_tool_name=request_config["assistant_message_tool_name"], assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"], + reverse=not ascending, ) + if request_config["include_return_message_types"]: + messages = [msg for msg in messages if msg.message_type in request_config["include_return_message_types"]] + return messages @enforce_types diff --git a/letta/services/tool_executor/multi_agent_tool_executor.py b/letta/services/tool_executor/multi_agent_tool_executor.py index 02accc2e..e4362697 100644 --- a/letta/services/tool_executor/multi_agent_tool_executor.py +++ b/letta/services/tool_executor/multi_agent_tool_executor.py @@ -101,6 +101,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor): message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, + job_manager=self.job_manager, passage_manager=self.passage_manager, actor=self.actor, ) diff --git 
a/letta/services/tool_executor/tool_execution_manager.py b/letta/services/tool_executor/tool_execution_manager.py index 2b4dad6a..f4f94ae5 100644 --- a/letta/services/tool_executor/tool_execution_manager.py +++ b/letta/services/tool_executor/tool_execution_manager.py @@ -15,6 +15,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager +from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor @@ -49,6 +50,7 @@ class ToolExecutorFactory: message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, + job_manager: JobManager, passage_manager: PassageManager, actor: User, ) -> ToolExecutor: @@ -58,6 +60,7 @@ class ToolExecutorFactory: message_manager=message_manager, agent_manager=agent_manager, block_manager=block_manager, + job_manager=job_manager, passage_manager=passage_manager, actor=actor, ) @@ -71,6 +74,7 @@ class ToolExecutionManager: message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, + job_manager: JobManager, passage_manager: PassageManager, actor: User, agent_state: Optional[AgentState] = None, @@ -80,6 +84,7 @@ class ToolExecutionManager: self.message_manager = message_manager self.agent_manager = agent_manager self.block_manager = block_manager + self.job_manager = job_manager self.passage_manager = passage_manager self.agent_state = agent_state self.logger = get_logger(__name__) @@ -101,6 +106,7 @@ class ToolExecutionManager: message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, + job_manager=self.job_manager, passage_manager=self.passage_manager, actor=self.actor, ) diff --git 
a/letta/services/tool_executor/tool_executor_base.py b/letta/services/tool_executor/tool_executor_base.py index a8a7ccb2..452ce681 100644 --- a/letta/services/tool_executor/tool_executor_base.py +++ b/letta/services/tool_executor/tool_executor_base.py @@ -8,6 +8,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager +from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager @@ -20,12 +21,14 @@ class ToolExecutor(ABC): message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, + job_manager: JobManager, passage_manager: PassageManager, actor: User, ): self.message_manager = message_manager self.agent_manager = agent_manager self.block_manager = block_manager + self.job_manager = job_manager self.passage_manager = passage_manager self.actor = actor diff --git a/tests/integration_test_composio.py b/tests/integration_test_composio.py index 6bd4272a..57a15e0a 100644 --- a/tests/integration_test_composio.py +++ b/tests/integration_test_composio.py @@ -71,6 +71,7 @@ async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_ message_manager=server.message_manager, agent_manager=server.agent_manager, block_manager=server.block_manager, + job_manager=server.job_manager, passage_manager=server.passage_manager, agent_state=agent_state, actor=default_user, diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 16c9e745..76dcd583 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -19,6 +19,7 @@ from letta_client.types import ( Base64Image, HiddenReasoningMessage, ImageContent, + LettaMessageUnion, LettaStopReason, LettaUsageStatistics, ReasoningMessage, @@ -351,20 
+352,24 @@ def accumulate_chunks(chunks: List[Any]) -> List[Any]: return [m for m in messages if m is not None] -def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None: - """ - Asserts that a list of message dictionaries contains the expected types and statuses. +def cast_message_dict_to_messages(messages: List[Dict[str, Any]]) -> List[LettaMessageUnion]: + def cast_message(message: Dict[str, Any]) -> LettaMessageUnion: + if message["message_type"] == "reasoning_message": + return ReasoningMessage(**message) + elif message["message_type"] == "assistant_message": + return AssistantMessage(**message) + elif message["message_type"] == "tool_call_message": + return ToolCallMessage(**message) + elif message["message_type"] == "tool_return_message": + return ToolReturnMessage(**message) + elif message["message_type"] == "user_message": + return UserMessage(**message) + elif message["message_type"] == "hidden_reasoning_message": + return HiddenReasoningMessage(**message) + else: + raise ValueError(f"Unknown message type: {message['message_type']}") - Expected order: - 1. reasoning_message - 2. tool_call_message - 3. tool_return_message (with status 'success') - 4. reasoning_message - 5. 
assistant_message - """ - assert isinstance(messages, list) - assert messages[0]["message_type"] == "reasoning_message" - assert messages[1]["message_type"] == "assistant_message" + return [cast_message(message) for message in messages] # ------------------------------ @@ -870,6 +875,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i if run.status == "completed": return run if run.status == "failed": + print(run) raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}") if time.time() - start > timeout: raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})") @@ -891,6 +897,7 @@ def test_async_greeting_with_assistant_message( Tests sending a message as an asynchronous job using the synchronous client. Waits for job completion and asserts that the result messages are as expected. """ + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) run = client.agents.messages.create_async( @@ -902,8 +909,86 @@ def test_async_greeting_with_assistant_message( result = run.metadata.get("result") assert result is not None, "Run metadata missing 'result' key" - messages = result["messages"] - assert_tool_response_dict_messages(messages) + messages = cast_message_dict_to_messages(result["messages"]) + assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) + + messages = client.runs.messages.list(run_id=run.id) + assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) + messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_async_greeting_without_assistant_message( + 
disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message as an asynchronous job using the synchronous client. + Waits for job completion and asserts that the result messages are as expected. + """ + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + run = client.agents.messages.create_async( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + use_assistant_message=False, + ) + run = wait_for_run_completion(client, run.id) + + result = run.metadata.get("result") + assert result is not None, "Run metadata missing 'result' key" + + messages = cast_message_dict_to_messages(result["messages"]) + assert_greeting_without_assistant_message_response(messages, llm_config=llm_config) + + messages = client.runs.messages.list(run_id=run.id) + assert_greeting_without_assistant_message_response(messages, llm_config=llm_config) + messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False) + assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_async_tool_call( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message as an asynchronous job using the synchronous client. + Waits for job completion and asserts that the result messages are as expected. 
+ """ + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + run = client.agents.messages.create_async( + agent_id=agent_state.id, + messages=USER_MESSAGE_ROLL_DICE, + ) + run = wait_for_run_completion(client, run.id) + + result = run.metadata.get("result") + assert result is not None, "Run metadata missing 'result' key" + + messages = cast_message_dict_to_messages(result["messages"]) + assert_tool_call_response(messages, llm_config=llm_config) + + messages = client.runs.messages.list(run_id=run.id) + assert_tool_call_response(messages, llm_config=llm_config) + messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config) class CallbackServer: @@ -1021,8 +1106,9 @@ def test_async_greeting_with_callback_url( # Validate job completed successfully result = run.metadata.get("result") assert result is not None, "Run metadata missing 'result' key" - messages = result["messages"] - assert_tool_response_dict_messages(messages) + + messages = cast_message_dict_to_messages(result["messages"]) + assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) # Validate callback was received assert server.wait_for_callback(timeout=15), "Callback was not received within timeout" @@ -1084,8 +1170,9 @@ def test_async_callback_failure_scenarios( # Validate job completed successfully result = run.metadata.get("result") assert result is not None, "Run metadata missing 'result' key" - messages = result["messages"] - assert_tool_response_dict_messages(messages) + + messages = cast_message_dict_to_messages(result["messages"]) + assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) # Job should be marked as completed even if callback failed assert run.status == "completed", f"Expected status 'completed', got {run.status}" diff 
--git a/tests/test_provider_trace.py b/tests/test_provider_trace.py index 43e13a34..574c4a1c 100644 --- a/tests/test_provider_trace.py +++ b/tests/test_provider_trace.py @@ -110,6 +110,7 @@ async def test_provider_trace_experimental_step(message, agent_state, default_us message_manager=MessageManager(), agent_manager=AgentManager(), block_manager=BlockManager(), + job_manager=JobManager(), passage_manager=PassageManager(), step_manager=StepManager(), telemetry_manager=TelemetryManager(), @@ -134,6 +135,7 @@ async def test_provider_trace_experimental_step_stream(message, agent_state, def message_manager=MessageManager(), agent_manager=AgentManager(), block_manager=BlockManager(), + job_manager=JobManager(), passage_manager=PassageManager(), step_manager=StepManager(), telemetry_manager=TelemetryManager(), @@ -189,6 +191,7 @@ async def test_noop_provider_trace(message, agent_state, default_user, event_loo message_manager=MessageManager(), agent_manager=AgentManager(), block_manager=BlockManager(), + job_manager=JobManager(), passage_manager=PassageManager(), step_manager=StepManager(), telemetry_manager=NoopTelemetryManager(), diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 0e81f40b..1e56adc9 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -637,7 +637,7 @@ def test_many_blocks(client: LettaSDKClient): # cases: steam, async, token stream, sync -@pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync"]) +@pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync", "async"]) def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, message_create: str): """Test that the include_return_message_types parameter works"""