feat: make create_async route consistent with other message routes (#2877)

This commit is contained in:
cthomas
2025-06-19 13:51:51 -07:00
committed by GitHub
parent 56493de971
commit 99e112e486
17 changed files with 209 additions and 36 deletions

View File

@@ -1000,11 +1000,12 @@ class Agent(BaseAgent):
) )
if job_id: if job_id:
for message in all_new_messages: for message in all_new_messages:
self.job_manager.add_message_to_job( if message.role != "user":
job_id=job_id, self.job_manager.add_message_to_job(
message_id=message.id, job_id=job_id,
actor=self.user, message_id=message.id,
) actor=self.user,
)
return AgentStepResponse( return AgentStepResponse(
messages=all_new_messages, messages=all_new_messages,

View File

@@ -50,7 +50,9 @@ class BaseAgent(ABC):
self.logger = get_logger(agent_id) self.logger = get_logger(agent_id)
@abstractmethod @abstractmethod
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse: async def step(
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: Optional[str] = None
) -> LettaResponse:
""" """
Main execution loop for the agent. Main execution loop for the agent.
""" """

View File

@@ -43,6 +43,7 @@ from letta.server.rest_api.utils import create_letta_messages_from_llm_response
from letta.services.agent_manager import AgentManager from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager from letta.services.block_manager import BlockManager
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager from letta.services.step_manager import NoopStepManager, StepManager
@@ -66,6 +67,7 @@ class LettaAgent(BaseAgent):
message_manager: MessageManager, message_manager: MessageManager,
agent_manager: AgentManager, agent_manager: AgentManager,
block_manager: BlockManager, block_manager: BlockManager,
job_manager: JobManager,
passage_manager: PassageManager, passage_manager: PassageManager,
actor: User, actor: User,
step_manager: StepManager = NoopStepManager(), step_manager: StepManager = NoopStepManager(),
@@ -81,6 +83,7 @@ class LettaAgent(BaseAgent):
# TODO: Make this more general, factorable # TODO: Make this more general, factorable
# Summarizer settings # Summarizer settings
self.block_manager = block_manager self.block_manager = block_manager
self.job_manager = job_manager
self.passage_manager = passage_manager self.passage_manager = passage_manager
self.step_manager = step_manager self.step_manager = step_manager
self.telemetry_manager = telemetry_manager self.telemetry_manager = telemetry_manager
@@ -120,6 +123,7 @@ class LettaAgent(BaseAgent):
self, self,
input_messages: List[MessageCreate], input_messages: List[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS, max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
use_assistant_message: bool = True, use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None, request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None, include_return_message_types: Optional[List[MessageType]] = None,
@@ -131,6 +135,7 @@ class LettaAgent(BaseAgent):
agent_state=agent_state, agent_state=agent_state,
input_messages=input_messages, input_messages=input_messages,
max_steps=max_steps, max_steps=max_steps,
run_id=run_id,
request_start_timestamp_ns=request_start_timestamp_ns, request_start_timestamp_ns=request_start_timestamp_ns,
) )
return _create_letta_response( return _create_letta_response(
@@ -193,7 +198,6 @@ class LettaAgent(BaseAgent):
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# update usage # update usage
# TODO: add run_id
usage.step_count += 1 usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens usage.prompt_tokens += response.usage.prompt_tokens
@@ -302,6 +306,7 @@ class LettaAgent(BaseAgent):
agent_state: AgentState, agent_state: AgentState,
input_messages: List[MessageCreate], input_messages: List[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS, max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
request_start_timestamp_ns: Optional[int] = None, request_start_timestamp_ns: Optional[int] = None,
) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]: ) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]:
""" """
@@ -345,11 +350,11 @@ class LettaAgent(BaseAgent):
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# TODO: add run_id
usage.step_count += 1 usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens usage.total_tokens += response.usage.total_tokens
usage.run_ids = [run_id] if run_id else None
MetricRegistry().message_output_tokens.record( MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
) )
@@ -385,6 +390,7 @@ class LettaAgent(BaseAgent):
initial_messages=initial_messages, initial_messages=initial_messages,
agent_step_span=agent_step_span, agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1), is_final_step=(i == max_steps - 1),
run_id=run_id,
) )
self.response_messages.extend(persisted_messages) self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages)
@@ -916,6 +922,7 @@ class LettaAgent(BaseAgent):
initial_messages: Optional[List[Message]] = None, initial_messages: Optional[List[Message]] = None,
agent_step_span: Optional["Span"] = None, agent_step_span: Optional["Span"] = None,
is_final_step: Optional[bool] = None, is_final_step: Optional[bool] = None,
run_id: Optional[str] = None,
) -> Tuple[List[Message], bool, Optional[LettaStopReason]]: ) -> Tuple[List[Message], bool, Optional[LettaStopReason]]:
""" """
Now that streaming is done, handle the final AI response. Now that streaming is done, handle the final AI response.
@@ -1027,7 +1034,7 @@ class LettaAgent(BaseAgent):
# 5a. Persist Steps to DB # 5a. Persist Steps to DB
# Following agent loop to persist this before messages # Following agent loop to persist this before messages
# TODO (cliandy): determine what should match old loop w/provider_id, job_id # TODO (cliandy): determine what should match old loop w/provider_id
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same. # TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
logged_step = await self.step_manager.log_step_async( logged_step = await self.step_manager.log_step_async(
actor=self.actor, actor=self.actor,
@@ -1039,7 +1046,7 @@ class LettaAgent(BaseAgent):
context_window_limit=agent_state.llm_config.context_window, context_window_limit=agent_state.llm_config.context_window,
usage=usage, usage=usage,
provider_id=None, provider_id=None,
job_id=None, job_id=run_id,
step_id=step_id, step_id=step_id,
) )
@@ -1065,6 +1072,13 @@ class LettaAgent(BaseAgent):
) )
self.last_function_response = function_response self.last_function_response = function_response
if run_id:
await self.job_manager.add_messages_to_job_async(
job_id=run_id,
message_ids=[message.id for message in persisted_messages if message.role != "user"],
actor=self.actor,
)
return persisted_messages, continue_stepping, stop_reason return persisted_messages, continue_stepping, stop_reason
@trace_method @trace_method
@@ -1102,6 +1116,7 @@ class LettaAgent(BaseAgent):
message_manager=self.message_manager, message_manager=self.message_manager,
agent_manager=self.agent_manager, agent_manager=self.agent_manager,
block_manager=self.block_manager, block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager, passage_manager=self.passage_manager,
sandbox_env_vars=sandbox_env_vars, sandbox_env_vars=sandbox_env_vars,
actor=self.actor, actor=self.actor,

View File

@@ -63,6 +63,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
self, self,
input_messages: List[MessageCreate], input_messages: List[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS, max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
use_assistant_message: bool = True, use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None, request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None, include_return_message_types: Optional[List[MessageType]] = None,
@@ -83,6 +84,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
message_manager=self.message_manager, message_manager=self.message_manager,
agent_manager=self.agent_manager, agent_manager=self.agent_manager,
block_manager=self.block_manager, block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager, passage_manager=self.passage_manager,
actor=self.actor, actor=self.actor,
step_manager=self.step_manager, step_manager=self.step_manager,
@@ -92,6 +94,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
response = await foreground_agent.step( response = await foreground_agent.step(
input_messages=new_messages, input_messages=new_messages,
max_steps=max_steps, max_steps=max_steps,
run_id=run_id,
use_assistant_message=use_assistant_message, use_assistant_message=use_assistant_message,
include_return_message_types=include_return_message_types, include_return_message_types=include_return_message_types,
) )
@@ -170,6 +173,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
message_manager=self.message_manager, message_manager=self.message_manager,
agent_manager=self.agent_manager, agent_manager=self.agent_manager,
block_manager=self.block_manager, block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager, passage_manager=self.passage_manager,
actor=self.actor, actor=self.actor,
step_manager=self.step_manager, step_manager=self.step_manager,
@@ -283,6 +287,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
message_manager=self.message_manager, message_manager=self.message_manager,
agent_manager=self.agent_manager, agent_manager=self.agent_manager,
block_manager=self.block_manager, block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager, passage_manager=self.passage_manager,
actor=self.actor, actor=self.actor,
step_manager=self.step_manager, step_manager=self.step_manager,
@@ -296,6 +301,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
result = await sleeptime_agent.step( result = await sleeptime_agent.step(
input_messages=sleeptime_agent_messages, input_messages=sleeptime_agent_messages,
use_assistant_message=use_assistant_message, use_assistant_message=use_assistant_message,
run_id=run_id,
) )
# Update job status # Update job status

View File

@@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.orm.enums import JobType from letta.orm.enums import JobType
from letta.schemas.enums import JobStatus from letta.schemas.enums import JobStatus
from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import MessageType
class JobBase(OrmMetadataBase): class JobBase(OrmMetadataBase):
@@ -94,3 +95,6 @@ class LettaRequestConfig(BaseModel):
default=DEFAULT_MESSAGE_TOOL_KWARG, default=DEFAULT_MESSAGE_TOOL_KWARG,
description="The name of the message argument in the designated message tool.", description="The name of the message argument in the designated message tool.",
) )
include_return_message_types: Optional[List[MessageType]] = Field(
default=None, description="Only return specified message types in the response. If `None` (default) returns all messages."
)

View File

@@ -39,6 +39,10 @@ class LettaStreamingRequest(LettaRequest):
) )
class LettaAsyncRequest(LettaRequest):
callback_url: Optional[str] = Field(None, description="Optional callback URL to POST to when the job completes")
class LettaBatchRequest(LettaRequest): class LettaBatchRequest(LettaRequest):
agent_id: str = Field(..., description="The ID of the agent to send this batch request for") agent_id: str = Field(..., description="The ID of the agent to send this batch request for")

View File

@@ -25,7 +25,7 @@ from letta.schemas.block import Block, BlockUpdate
from letta.schemas.group import Group from letta.schemas.group import Group
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.message import MessageCreate from letta.schemas.message import MessageCreate
@@ -707,6 +707,7 @@ async def send_message(
message_manager=server.message_manager, message_manager=server.message_manager,
agent_manager=server.agent_manager, agent_manager=server.agent_manager,
block_manager=server.block_manager, block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager, passage_manager=server.passage_manager,
actor=actor, actor=actor,
step_manager=server.step_manager, step_manager=server.step_manager,
@@ -793,6 +794,7 @@ async def send_message_streaming(
message_manager=server.message_manager, message_manager=server.message_manager,
agent_manager=server.agent_manager, agent_manager=server.agent_manager,
block_manager=server.block_manager, block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager, passage_manager=server.passage_manager,
actor=actor, actor=actor,
step_manager=server.step_manager, step_manager=server.step_manager,
@@ -884,6 +886,7 @@ async def process_message_background(
message_manager=server.message_manager, message_manager=server.message_manager,
agent_manager=server.agent_manager, agent_manager=server.agent_manager,
block_manager=server.block_manager, block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager, passage_manager=server.passage_manager,
actor=actor, actor=actor,
step_manager=server.step_manager, step_manager=server.step_manager,
@@ -893,6 +896,7 @@ async def process_message_background(
result = await agent_loop.step( result = await agent_loop.step(
messages, messages,
max_steps=max_steps, max_steps=max_steps,
run_id=job_id,
use_assistant_message=use_assistant_message, use_assistant_message=use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns, request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=include_return_message_types, include_return_message_types=include_return_message_types,
@@ -904,6 +908,7 @@ async def process_message_background(
input_messages=messages, input_messages=messages,
stream_steps=False, stream_steps=False,
stream_tokens=False, stream_tokens=False,
metadata={"job_id": job_id},
# Support for AssistantMessage # Support for AssistantMessage
use_assistant_message=use_assistant_message, use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_name=assistant_message_tool_name,
@@ -936,9 +941,8 @@ async def process_message_background(
async def send_message_async( async def send_message_async(
agent_id: str, agent_id: str,
server: SyncServer = Depends(get_letta_server), server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...), request: LettaAsyncRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), actor_id: Optional[str] = Header(None, alias="user_id"),
callback_url: Optional[str] = Query(None, description="Optional callback URL to POST to when the job completes"),
): ):
""" """
Asynchronously process a user message and return a run object. Asynchronously process a user message and return a run object.
@@ -951,7 +955,7 @@ async def send_message_async(
run = Run( run = Run(
user_id=actor.id, user_id=actor.id,
status=JobStatus.created, status=JobStatus.created,
callback_url=callback_url, callback_url=request.callback_url,
metadata={ metadata={
"job_type": "send_message_async", "job_type": "send_message_async",
"agent_id": agent_id, "agent_id": agent_id,
@@ -960,6 +964,7 @@ async def send_message_async(
use_assistant_message=request.use_assistant_message, use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name, assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
), ),
) )
run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor) run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor)
@@ -1036,6 +1041,7 @@ async def summarize_agent_conversation(
message_manager=server.message_manager, message_manager=server.message_manager,
agent_manager=server.agent_manager, agent_manager=server.agent_manager,
block_manager=server.block_manager, block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager, passage_manager=server.passage_manager,
actor=actor, actor=actor,
step_manager=server.step_manager, step_manager=server.step_manager,

View File

@@ -92,7 +92,7 @@ async def list_run_messages(
after: Optional[str] = Query(None, description="Cursor for pagination"), after: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(100, description="Maximum number of messages to return"), limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
order: str = Query( order: str = Query(
"desc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order." "asc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order."
), ),
role: Optional[MessageRole] = Query(None, description="Filter by role"), role: Optional[MessageRole] = Query(None, description="Filter by role"),
): ):

View File

@@ -1355,6 +1355,7 @@ class SyncServer(Server):
message_manager=self.message_manager, message_manager=self.message_manager,
agent_manager=self.agent_manager, agent_manager=self.agent_manager,
block_manager=self.block_manager, block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager, passage_manager=self.passage_manager,
actor=actor, actor=actor,
step_manager=self.step_manager, step_manager=self.step_manager,
@@ -1996,6 +1997,7 @@ class SyncServer(Server):
message_manager=self.message_manager, message_manager=self.message_manager,
agent_manager=self.agent_manager, agent_manager=self.agent_manager,
block_manager=self.block_manager, block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager, passage_manager=self.passage_manager,
actor=actor, actor=actor,
sandbox_env_vars=tool_env_vars, sandbox_env_vars=tool_env_vars,

View File

@@ -342,6 +342,33 @@ class JobManager:
session.add(job_message) session.add(job_message)
session.commit() session.commit()
@enforce_types
@trace_method
async def add_messages_to_job_async(self, job_id: str, message_ids: List[str], actor: PydanticUser) -> None:
"""
Associate a message with a job by creating a JobMessage record.
Each message can only be associated with one job.
Args:
job_id: The ID of the job
message_id: The ID of the message to associate
actor: The user making the request
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
if not message_ids:
return
async with db_registry.async_session() as session:
# First verify job exists and user has access
await self._verify_job_access_async(session, job_id, actor, access=["write"])
# Create new JobMessage associations
job_messages = [JobMessage(job_id=job_id, message_id=message_id) for message_id in message_ids]
session.add_all(job_messages)
await session.commit()
@enforce_types @enforce_types
@trace_method @trace_method
def get_job_usage(self, job_id: str, actor: PydanticUser) -> LettaUsageStatistics: def get_job_usage(self, job_id: str, actor: PydanticUser) -> LettaUsageStatistics:
@@ -463,14 +490,19 @@ class JobManager:
) )
request_config = self._get_run_request_config(run_id) request_config = self._get_run_request_config(run_id)
print("request_config", request_config)
messages = PydanticMessage.to_letta_messages_from_list( messages = PydanticMessage.to_letta_messages_from_list(
messages=messages, messages=messages,
use_assistant_message=request_config["use_assistant_message"], use_assistant_message=request_config["use_assistant_message"],
assistant_message_tool_name=request_config["assistant_message_tool_name"], assistant_message_tool_name=request_config["assistant_message_tool_name"],
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"], assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
reverse=not ascending,
) )
if request_config["include_return_message_types"]:
messages = [msg for msg in messages if msg.message_type in request_config["include_return_message_types"]]
return messages return messages
@enforce_types @enforce_types

View File

@@ -101,6 +101,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
message_manager=self.message_manager, message_manager=self.message_manager,
agent_manager=self.agent_manager, agent_manager=self.agent_manager,
block_manager=self.block_manager, block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager, passage_manager=self.passage_manager,
actor=self.actor, actor=self.actor,
) )

View File

@@ -15,6 +15,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User from letta.schemas.user import User
from letta.services.agent_manager import AgentManager from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager from letta.services.block_manager import BlockManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager from letta.services.passage_manager import PassageManager
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
@@ -49,6 +50,7 @@ class ToolExecutorFactory:
message_manager: MessageManager, message_manager: MessageManager,
agent_manager: AgentManager, agent_manager: AgentManager,
block_manager: BlockManager, block_manager: BlockManager,
job_manager: JobManager,
passage_manager: PassageManager, passage_manager: PassageManager,
actor: User, actor: User,
) -> ToolExecutor: ) -> ToolExecutor:
@@ -58,6 +60,7 @@ class ToolExecutorFactory:
message_manager=message_manager, message_manager=message_manager,
agent_manager=agent_manager, agent_manager=agent_manager,
block_manager=block_manager, block_manager=block_manager,
job_manager=job_manager,
passage_manager=passage_manager, passage_manager=passage_manager,
actor=actor, actor=actor,
) )
@@ -71,6 +74,7 @@ class ToolExecutionManager:
message_manager: MessageManager, message_manager: MessageManager,
agent_manager: AgentManager, agent_manager: AgentManager,
block_manager: BlockManager, block_manager: BlockManager,
job_manager: JobManager,
passage_manager: PassageManager, passage_manager: PassageManager,
actor: User, actor: User,
agent_state: Optional[AgentState] = None, agent_state: Optional[AgentState] = None,
@@ -80,6 +84,7 @@ class ToolExecutionManager:
self.message_manager = message_manager self.message_manager = message_manager
self.agent_manager = agent_manager self.agent_manager = agent_manager
self.block_manager = block_manager self.block_manager = block_manager
self.job_manager = job_manager
self.passage_manager = passage_manager self.passage_manager = passage_manager
self.agent_state = agent_state self.agent_state = agent_state
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
@@ -101,6 +106,7 @@ class ToolExecutionManager:
message_manager=self.message_manager, message_manager=self.message_manager,
agent_manager=self.agent_manager, agent_manager=self.agent_manager,
block_manager=self.block_manager, block_manager=self.block_manager,
job_manager=self.job_manager,
passage_manager=self.passage_manager, passage_manager=self.passage_manager,
actor=self.actor, actor=self.actor,
) )

View File

@@ -8,6 +8,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User from letta.schemas.user import User
from letta.services.agent_manager import AgentManager from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager from letta.services.block_manager import BlockManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager from letta.services.passage_manager import PassageManager
@@ -20,12 +21,14 @@ class ToolExecutor(ABC):
message_manager: MessageManager, message_manager: MessageManager,
agent_manager: AgentManager, agent_manager: AgentManager,
block_manager: BlockManager, block_manager: BlockManager,
job_manager: JobManager,
passage_manager: PassageManager, passage_manager: PassageManager,
actor: User, actor: User,
): ):
self.message_manager = message_manager self.message_manager = message_manager
self.agent_manager = agent_manager self.agent_manager = agent_manager
self.block_manager = block_manager self.block_manager = block_manager
self.job_manager = job_manager
self.passage_manager = passage_manager self.passage_manager = passage_manager
self.actor = actor self.actor = actor

View File

@@ -71,6 +71,7 @@ async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_
message_manager=server.message_manager, message_manager=server.message_manager,
agent_manager=server.agent_manager, agent_manager=server.agent_manager,
block_manager=server.block_manager, block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager, passage_manager=server.passage_manager,
agent_state=agent_state, agent_state=agent_state,
actor=default_user, actor=default_user,

View File

@@ -19,6 +19,7 @@ from letta_client.types import (
Base64Image, Base64Image,
HiddenReasoningMessage, HiddenReasoningMessage,
ImageContent, ImageContent,
LettaMessageUnion,
LettaStopReason, LettaStopReason,
LettaUsageStatistics, LettaUsageStatistics,
ReasoningMessage, ReasoningMessage,
@@ -351,20 +352,24 @@ def accumulate_chunks(chunks: List[Any]) -> List[Any]:
return [m for m in messages if m is not None] return [m for m in messages if m is not None]
def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None: def cast_message_dict_to_messages(messages: List[Dict[str, Any]]) -> List[LettaMessageUnion]:
""" def cast_message(message: Dict[str, Any]) -> LettaMessageUnion:
Asserts that a list of message dictionaries contains the expected types and statuses. if message["message_type"] == "reasoning_message":
return ReasoningMessage(**message)
elif message["message_type"] == "assistant_message":
return AssistantMessage(**message)
elif message["message_type"] == "tool_call_message":
return ToolCallMessage(**message)
elif message["message_type"] == "tool_return_message":
return ToolReturnMessage(**message)
elif message["message_type"] == "user_message":
return UserMessage(**message)
elif message["message_type"] == "hidden_reasoning_message":
return HiddenReasoningMessage(**message)
else:
raise ValueError(f"Unknown message type: {message['message_type']}")
Expected order: return [cast_message(message) for message in messages]
1. reasoning_message
2. tool_call_message
3. tool_return_message (with status 'success')
4. reasoning_message
5. assistant_message
"""
assert isinstance(messages, list)
assert messages[0]["message_type"] == "reasoning_message"
assert messages[1]["message_type"] == "assistant_message"
# ------------------------------ # ------------------------------
@@ -870,6 +875,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i
if run.status == "completed": if run.status == "completed":
return run return run
if run.status == "failed": if run.status == "failed":
print(run)
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}") raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
if time.time() - start > timeout: if time.time() - start > timeout:
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})") raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
@@ -891,6 +897,7 @@ def test_async_greeting_with_assistant_message(
Tests sending a message as an asynchronous job using the synchronous client. Tests sending a message as an asynchronous job using the synchronous client.
Waits for job completion and asserts that the result messages are as expected. Waits for job completion and asserts that the result messages are as expected.
""" """
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
run = client.agents.messages.create_async( run = client.agents.messages.create_async(
@@ -902,8 +909,86 @@ def test_async_greeting_with_assistant_message(
result = run.metadata.get("result") result = run.metadata.get("result")
assert result is not None, "Run metadata missing 'result' key" assert result is not None, "Run metadata missing 'result' key"
messages = result["messages"] messages = cast_message_dict_to_messages(result["messages"])
assert_tool_response_dict_messages(messages) assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
messages = client.runs.messages.list(run_id=run.id)
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_async_greeting_without_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Submits a greeting as an asynchronous run with use_assistant_message=False,
    waits for completion, and checks the produced messages through three views:
    the run's result metadata, the run's message listing, and the agent's
    persisted message history.
    """
    # Capture the latest persisted message so the DB assertion below only
    # inspects messages created by this run.
    prior = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    submitted = client.agents.messages.create_async(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
        use_assistant_message=False,
    )
    finished = wait_for_run_completion(client, submitted.id)

    payload = finished.metadata.get("result")
    assert payload is not None, "Run metadata missing 'result' key"

    # View 1: messages embedded in the run result metadata.
    assert_greeting_without_assistant_message_response(
        cast_message_dict_to_messages(payload["messages"]), llm_config=llm_config
    )

    # View 2: messages listed directly from the run.
    assert_greeting_without_assistant_message_response(
        client.runs.messages.list(run_id=finished.id), llm_config=llm_config
    )

    # View 3: messages persisted on the agent after the previously-latest one.
    persisted = client.agents.messages.list(
        agent_id=agent_state.id, after=prior[0].id, use_assistant_message=False
    )
    assert_greeting_without_assistant_message_response(persisted, from_db=True, llm_config=llm_config)
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_async_tool_call(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Submits a tool-invoking message as an asynchronous run, waits for it to
    complete, and verifies the tool-call message sequence through three views:
    the run's result metadata, the run's message listing, and the agent's
    persisted message history.
    """
    # Capture the latest persisted message so the DB assertion below only
    # inspects messages created by this run.
    prior = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    submitted = client.agents.messages.create_async(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_ROLL_DICE,
    )
    finished = wait_for_run_completion(client, submitted.id)

    payload = finished.metadata.get("result")
    assert payload is not None, "Run metadata missing 'result' key"

    # View 1: messages embedded in the run result metadata.
    assert_tool_call_response(
        cast_message_dict_to_messages(payload["messages"]), llm_config=llm_config
    )

    # View 2: messages listed directly from the run.
    assert_tool_call_response(
        client.runs.messages.list(run_id=finished.id), llm_config=llm_config
    )

    # View 3: messages persisted on the agent after the previously-latest one.
    persisted = client.agents.messages.list(agent_id=agent_state.id, after=prior[0].id)
    assert_tool_call_response(persisted, from_db=True, llm_config=llm_config)
class CallbackServer: class CallbackServer:
@@ -1021,8 +1106,9 @@ def test_async_greeting_with_callback_url(
# Validate job completed successfully # Validate job completed successfully
result = run.metadata.get("result") result = run.metadata.get("result")
assert result is not None, "Run metadata missing 'result' key" assert result is not None, "Run metadata missing 'result' key"
messages = result["messages"]
assert_tool_response_dict_messages(messages) messages = cast_message_dict_to_messages(result["messages"])
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
# Validate callback was received # Validate callback was received
assert server.wait_for_callback(timeout=15), "Callback was not received within timeout" assert server.wait_for_callback(timeout=15), "Callback was not received within timeout"
@@ -1084,8 +1170,9 @@ def test_async_callback_failure_scenarios(
# Validate job completed successfully # Validate job completed successfully
result = run.metadata.get("result") result = run.metadata.get("result")
assert result is not None, "Run metadata missing 'result' key" assert result is not None, "Run metadata missing 'result' key"
messages = result["messages"]
assert_tool_response_dict_messages(messages) messages = cast_message_dict_to_messages(result["messages"])
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
# Job should be marked as completed even if callback failed # Job should be marked as completed even if callback failed
assert run.status == "completed", f"Expected status 'completed', got {run.status}" assert run.status == "completed", f"Expected status 'completed', got {run.status}"

View File

@@ -110,6 +110,7 @@ async def test_provider_trace_experimental_step(message, agent_state, default_us
message_manager=MessageManager(), message_manager=MessageManager(),
agent_manager=AgentManager(), agent_manager=AgentManager(),
block_manager=BlockManager(), block_manager=BlockManager(),
job_manager=JobManager(),
passage_manager=PassageManager(), passage_manager=PassageManager(),
step_manager=StepManager(), step_manager=StepManager(),
telemetry_manager=TelemetryManager(), telemetry_manager=TelemetryManager(),
@@ -134,6 +135,7 @@ async def test_provider_trace_experimental_step_stream(message, agent_state, def
message_manager=MessageManager(), message_manager=MessageManager(),
agent_manager=AgentManager(), agent_manager=AgentManager(),
block_manager=BlockManager(), block_manager=BlockManager(),
job_manager=JobManager(),
passage_manager=PassageManager(), passage_manager=PassageManager(),
step_manager=StepManager(), step_manager=StepManager(),
telemetry_manager=TelemetryManager(), telemetry_manager=TelemetryManager(),
@@ -189,6 +191,7 @@ async def test_noop_provider_trace(message, agent_state, default_user, event_loo
message_manager=MessageManager(), message_manager=MessageManager(),
agent_manager=AgentManager(), agent_manager=AgentManager(),
block_manager=BlockManager(), block_manager=BlockManager(),
job_manager=JobManager(),
passage_manager=PassageManager(), passage_manager=PassageManager(),
step_manager=StepManager(), step_manager=StepManager(),
telemetry_manager=NoopTelemetryManager(), telemetry_manager=NoopTelemetryManager(),

View File

@@ -637,7 +637,7 @@ def test_many_blocks(client: LettaSDKClient):
# cases: steam, async, token stream, sync # cases: steam, async, token stream, sync
@pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync"]) @pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync", "async"])
def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, message_create: str): def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, message_create: str):
"""Test that the include_return_message_types parameter works""" """Test that the include_return_message_types parameter works"""