feat: make create_async route consistent with other message routes (#2877)
This commit is contained in:
@@ -1000,11 +1000,12 @@ class Agent(BaseAgent):
|
||||
)
|
||||
if job_id:
|
||||
for message in all_new_messages:
|
||||
self.job_manager.add_message_to_job(
|
||||
job_id=job_id,
|
||||
message_id=message.id,
|
||||
actor=self.user,
|
||||
)
|
||||
if message.role != "user":
|
||||
self.job_manager.add_message_to_job(
|
||||
job_id=job_id,
|
||||
message_id=message.id,
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
return AgentStepResponse(
|
||||
messages=all_new_messages,
|
||||
|
||||
@@ -50,7 +50,9 @@ class BaseAgent(ABC):
|
||||
self.logger = get_logger(agent_id)
|
||||
|
||||
@abstractmethod
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
|
||||
async def step(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: Optional[str] = None
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Main execution loop for the agent.
|
||||
"""
|
||||
|
||||
@@ -43,6 +43,7 @@ from letta.server.rest_api.utils import create_letta_messages_from_llm_response
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.step_manager import NoopStepManager, StepManager
|
||||
@@ -66,6 +67,7 @@ class LettaAgent(BaseAgent):
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
job_manager: JobManager,
|
||||
passage_manager: PassageManager,
|
||||
actor: User,
|
||||
step_manager: StepManager = NoopStepManager(),
|
||||
@@ -81,6 +83,7 @@ class LettaAgent(BaseAgent):
|
||||
# TODO: Make this more general, factorable
|
||||
# Summarizer settings
|
||||
self.block_manager = block_manager
|
||||
self.job_manager = job_manager
|
||||
self.passage_manager = passage_manager
|
||||
self.step_manager = step_manager
|
||||
self.telemetry_manager = telemetry_manager
|
||||
@@ -120,6 +123,7 @@ class LettaAgent(BaseAgent):
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: Optional[str] = None,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
@@ -131,6 +135,7 @@ class LettaAgent(BaseAgent):
|
||||
agent_state=agent_state,
|
||||
input_messages=input_messages,
|
||||
max_steps=max_steps,
|
||||
run_id=run_id,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
)
|
||||
return _create_letta_response(
|
||||
@@ -193,7 +198,6 @@ class LettaAgent(BaseAgent):
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||
|
||||
# update usage
|
||||
# TODO: add run_id
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
@@ -302,6 +306,7 @@ class LettaAgent(BaseAgent):
|
||||
agent_state: AgentState,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: Optional[str] = None,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]:
|
||||
"""
|
||||
@@ -345,11 +350,11 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||
|
||||
# TODO: add run_id
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
usage.total_tokens += response.usage.total_tokens
|
||||
usage.run_ids = [run_id] if run_id else None
|
||||
MetricRegistry().message_output_tokens.record(
|
||||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||
)
|
||||
@@ -385,6 +390,7 @@ class LettaAgent(BaseAgent):
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
run_id=run_id,
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -916,6 +922,7 @@ class LettaAgent(BaseAgent):
|
||||
initial_messages: Optional[List[Message]] = None,
|
||||
agent_step_span: Optional["Span"] = None,
|
||||
is_final_step: Optional[bool] = None,
|
||||
run_id: Optional[str] = None,
|
||||
) -> Tuple[List[Message], bool, Optional[LettaStopReason]]:
|
||||
"""
|
||||
Now that streaming is done, handle the final AI response.
|
||||
@@ -1027,7 +1034,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
# 5a. Persist Steps to DB
|
||||
# Following agent loop to persist this before messages
|
||||
# TODO (cliandy): determine what should match old loop w/provider_id, job_id
|
||||
# TODO (cliandy): determine what should match old loop w/provider_id
|
||||
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
|
||||
logged_step = await self.step_manager.log_step_async(
|
||||
actor=self.actor,
|
||||
@@ -1039,7 +1046,7 @@ class LettaAgent(BaseAgent):
|
||||
context_window_limit=agent_state.llm_config.context_window,
|
||||
usage=usage,
|
||||
provider_id=None,
|
||||
job_id=None,
|
||||
job_id=run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
|
||||
@@ -1065,6 +1072,13 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
self.last_function_response = function_response
|
||||
|
||||
if run_id:
|
||||
await self.job_manager.add_messages_to_job_async(
|
||||
job_id=run_id,
|
||||
message_ids=[message.id for message in persisted_messages if message.role != "user"],
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
|
||||
@trace_method
|
||||
@@ -1102,6 +1116,7 @@ class LettaAgent(BaseAgent):
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
job_manager=self.job_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
actor=self.actor,
|
||||
|
||||
@@ -63,6 +63,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: Optional[str] = None,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
@@ -83,6 +84,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
job_manager=self.job_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
@@ -92,6 +94,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
response = await foreground_agent.step(
|
||||
input_messages=new_messages,
|
||||
max_steps=max_steps,
|
||||
run_id=run_id,
|
||||
use_assistant_message=use_assistant_message,
|
||||
include_return_message_types=include_return_message_types,
|
||||
)
|
||||
@@ -170,6 +173,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
job_manager=self.job_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
@@ -283,6 +287,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
job_manager=self.job_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
@@ -296,6 +301,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
result = await sleeptime_agent.step(
|
||||
input_messages=sleeptime_agent_messages,
|
||||
use_assistant_message=use_assistant_message,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
# Update job status
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.orm.enums import JobType
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import MessageType
|
||||
|
||||
|
||||
class JobBase(OrmMetadataBase):
|
||||
@@ -94,3 +95,6 @@ class LettaRequestConfig(BaseModel):
|
||||
default=DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
description="The name of the message argument in the designated message tool.",
|
||||
)
|
||||
include_return_message_types: Optional[List[MessageType]] = Field(
|
||||
default=None, description="Only return specified message types in the response. If `None` (default) returns all messages."
|
||||
)
|
||||
|
||||
@@ -39,6 +39,10 @@ class LettaStreamingRequest(LettaRequest):
|
||||
)
|
||||
|
||||
|
||||
class LettaAsyncRequest(LettaRequest):
    # Request body used by the asynchronous send-message route: everything a
    # LettaRequest carries, plus an optional URL to notify on job completion.
    callback_url: Optional[str] = Field(
        default=None,
        description="Optional callback URL to POST to when the job completes",
    )
|
||||
|
||||
|
||||
class LettaBatchRequest(LettaRequest):
|
||||
agent_id: str = Field(..., description="The ID of the agent to send this batch request for")
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
|
||||
from letta.schemas.message import MessageCreate
|
||||
@@ -707,6 +707,7 @@ async def send_message(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
@@ -793,6 +794,7 @@ async def send_message_streaming(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
@@ -884,6 +886,7 @@ async def process_message_background(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
@@ -893,6 +896,7 @@ async def process_message_background(
|
||||
result = await agent_loop.step(
|
||||
messages,
|
||||
max_steps=max_steps,
|
||||
run_id=job_id,
|
||||
use_assistant_message=use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=include_return_message_types,
|
||||
@@ -904,6 +908,7 @@ async def process_message_background(
|
||||
input_messages=messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
metadata={"job_id": job_id},
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
@@ -936,9 +941,8 @@ async def process_message_background(
|
||||
async def send_message_async(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
request: LettaAsyncRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
callback_url: Optional[str] = Query(None, description="Optional callback URL to POST to when the job completes"),
|
||||
):
|
||||
"""
|
||||
Asynchronously process a user message and return a run object.
|
||||
@@ -951,7 +955,7 @@ async def send_message_async(
|
||||
run = Run(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.created,
|
||||
callback_url=callback_url,
|
||||
callback_url=request.callback_url,
|
||||
metadata={
|
||||
"job_type": "send_message_async",
|
||||
"agent_id": agent_id,
|
||||
@@ -960,6 +964,7 @@ async def send_message_async(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
)
|
||||
run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor)
|
||||
@@ -1036,6 +1041,7 @@ async def summarize_agent_conversation(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
|
||||
@@ -92,7 +92,7 @@ async def list_run_messages(
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
|
||||
order: str = Query(
|
||||
"desc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order."
|
||||
"asc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order."
|
||||
),
|
||||
role: Optional[MessageRole] = Query(None, description="Filter by role"),
|
||||
):
|
||||
|
||||
@@ -1355,6 +1355,7 @@ class SyncServer(Server):
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
job_manager=self.job_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=self.step_manager,
|
||||
@@ -1996,6 +1997,7 @@ class SyncServer(Server):
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
job_manager=self.job_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=actor,
|
||||
sandbox_env_vars=tool_env_vars,
|
||||
|
||||
@@ -342,6 +342,33 @@ class JobManager:
|
||||
session.add(job_message)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
@trace_method
async def add_messages_to_job_async(self, job_id: str, message_ids: List[str], actor: PydanticUser) -> None:
    """
    Associate a batch of messages with a job by creating one JobMessage record per message.
    Each message can only be associated with one job.

    Args:
        job_id: The ID of the job
        message_ids: The IDs of the messages to associate; an empty list is a no-op
        actor: The user making the request

    Raises:
        NoResultFound: If the job does not exist or user does not have access
    """
    # Nothing to link: return before opening a database session.
    if not message_ids:
        return

    async with db_registry.async_session() as session:
        # First verify job exists and user has access
        await self._verify_job_access_async(session, job_id, actor, access=["write"])

        # Stage every association, then persist them all with a single commit
        job_messages = [JobMessage(job_id=job_id, message_id=message_id) for message_id in message_ids]
        session.add_all(job_messages)
        await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_job_usage(self, job_id: str, actor: PydanticUser) -> LettaUsageStatistics:
|
||||
@@ -463,14 +490,19 @@ class JobManager:
|
||||
)
|
||||
|
||||
request_config = self._get_run_request_config(run_id)
|
||||
print("request_config", request_config)
|
||||
|
||||
messages = PydanticMessage.to_letta_messages_from_list(
|
||||
messages=messages,
|
||||
use_assistant_message=request_config["use_assistant_message"],
|
||||
assistant_message_tool_name=request_config["assistant_message_tool_name"],
|
||||
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
|
||||
reverse=not ascending,
|
||||
)
|
||||
|
||||
if request_config["include_return_message_types"]:
|
||||
messages = [msg for msg in messages if msg.message_type in request_config["include_return_message_types"]]
|
||||
|
||||
return messages
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -101,6 +101,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
job_manager=self.job_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
|
||||
@@ -49,6 +50,7 @@ class ToolExecutorFactory:
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
job_manager: JobManager,
|
||||
passage_manager: PassageManager,
|
||||
actor: User,
|
||||
) -> ToolExecutor:
|
||||
@@ -58,6 +60,7 @@ class ToolExecutorFactory:
|
||||
message_manager=message_manager,
|
||||
agent_manager=agent_manager,
|
||||
block_manager=block_manager,
|
||||
job_manager=job_manager,
|
||||
passage_manager=passage_manager,
|
||||
actor=actor,
|
||||
)
|
||||
@@ -71,6 +74,7 @@ class ToolExecutionManager:
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
job_manager: JobManager,
|
||||
passage_manager: PassageManager,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
@@ -80,6 +84,7 @@ class ToolExecutionManager:
|
||||
self.message_manager = message_manager
|
||||
self.agent_manager = agent_manager
|
||||
self.block_manager = block_manager
|
||||
self.job_manager = job_manager
|
||||
self.passage_manager = passage_manager
|
||||
self.agent_state = agent_state
|
||||
self.logger = get_logger(__name__)
|
||||
@@ -101,6 +106,7 @@ class ToolExecutionManager:
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
job_manager=self.job_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
|
||||
@@ -20,12 +21,14 @@ class ToolExecutor(ABC):
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
job_manager: JobManager,
|
||||
passage_manager: PassageManager,
|
||||
actor: User,
|
||||
):
|
||||
self.message_manager = message_manager
|
||||
self.agent_manager = agent_manager
|
||||
self.block_manager = block_manager
|
||||
self.job_manager = job_manager
|
||||
self.passage_manager = passage_manager
|
||||
self.actor = actor
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
agent_state=agent_state,
|
||||
actor=default_user,
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta_client.types import (
|
||||
Base64Image,
|
||||
HiddenReasoningMessage,
|
||||
ImageContent,
|
||||
LettaMessageUnion,
|
||||
LettaStopReason,
|
||||
LettaUsageStatistics,
|
||||
ReasoningMessage,
|
||||
@@ -351,20 +352,24 @@ def accumulate_chunks(chunks: List[Any]) -> List[Any]:
|
||||
return [m for m in messages if m is not None]
|
||||
|
||||
|
||||
def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Asserts that a list of message dictionaries contains the expected types and statuses.
|
||||
def cast_message_dict_to_messages(messages: List[Dict[str, Any]]) -> List[LettaMessageUnion]:
|
||||
def cast_message(message: Dict[str, Any]) -> LettaMessageUnion:
|
||||
if message["message_type"] == "reasoning_message":
|
||||
return ReasoningMessage(**message)
|
||||
elif message["message_type"] == "assistant_message":
|
||||
return AssistantMessage(**message)
|
||||
elif message["message_type"] == "tool_call_message":
|
||||
return ToolCallMessage(**message)
|
||||
elif message["message_type"] == "tool_return_message":
|
||||
return ToolReturnMessage(**message)
|
||||
elif message["message_type"] == "user_message":
|
||||
return UserMessage(**message)
|
||||
elif message["message_type"] == "hidden_reasoning_message":
|
||||
return HiddenReasoningMessage(**message)
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {message['message_type']}")
|
||||
|
||||
Expected order:
|
||||
1. reasoning_message
|
||||
2. tool_call_message
|
||||
3. tool_return_message (with status 'success')
|
||||
4. reasoning_message
|
||||
5. assistant_message
|
||||
"""
|
||||
assert isinstance(messages, list)
|
||||
assert messages[0]["message_type"] == "reasoning_message"
|
||||
assert messages[1]["message_type"] == "assistant_message"
|
||||
return [cast_message(message) for message in messages]
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@@ -870,6 +875,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i
|
||||
if run.status == "completed":
|
||||
return run
|
||||
if run.status == "failed":
|
||||
print(run)
|
||||
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
|
||||
@@ -891,6 +897,7 @@ def test_async_greeting_with_assistant_message(
|
||||
Tests sending a message as an asynchronous job using the synchronous client.
|
||||
Waits for job completion and asserts that the result messages are as expected.
|
||||
"""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
run = client.agents.messages.create_async(
|
||||
@@ -902,8 +909,86 @@ def test_async_greeting_with_assistant_message(
|
||||
result = run.metadata.get("result")
|
||||
assert result is not None, "Run metadata missing 'result' key"
|
||||
|
||||
messages = result["messages"]
|
||||
assert_tool_response_dict_messages(messages)
|
||||
messages = cast_message_dict_to_messages(result["messages"])
|
||||
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
|
||||
|
||||
messages = client.runs.messages.list(run_id=run.id)
|
||||
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[config.model for config in TESTED_LLM_CONFIGS],
)
def test_async_greeting_without_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Submit a forced-reply greeting through `create_async` with
    `use_assistant_message=False` and wait for the run to complete.

    The response shape is then checked from three sources: the `result` stored
    in the run metadata, the run's own message listing, and the agent's message
    history after the message that was most recent before the request.
    """
    # Capture the most recent message first so the history check below can use it as a cursor.
    latest_before_request = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    submitted_run = client.agents.messages.create_async(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
        use_assistant_message=False,
    )
    finished_run = wait_for_run_completion(client, submitted_run.id)

    # 1. Messages embedded in the run metadata (plain dicts, cast to typed messages).
    result = finished_run.metadata.get("result")
    assert result is not None, "Run metadata missing 'result' key"
    assert_greeting_without_assistant_message_response(
        cast_message_dict_to_messages(result["messages"]),
        llm_config=llm_config,
    )

    # 2. Messages listed through the runs API.
    run_messages = client.runs.messages.list(run_id=finished_run.id)
    assert_greeting_without_assistant_message_response(run_messages, llm_config=llm_config)

    # 3. Messages persisted on the agent since the cursor captured above.
    persisted_messages = client.agents.messages.list(
        agent_id=agent_state.id,
        after=latest_before_request[0].id,
        use_assistant_message=False,
    )
    assert_greeting_without_assistant_message_response(persisted_messages, from_db=True, llm_config=llm_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[config.model for config in TESTED_LLM_CONFIGS],
)
def test_async_tool_call(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Submit the dice-roll message through `create_async` and wait for the run
    to complete.

    The tool-call response shape is then checked from three sources: the
    `result` stored in the run metadata, the run's own message listing, and the
    agent's message history after the message that was most recent before the
    request.
    """
    # Capture the most recent message first so the history check below can use it as a cursor.
    latest_before_request = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    submitted_run = client.agents.messages.create_async(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_ROLL_DICE,
    )
    finished_run = wait_for_run_completion(client, submitted_run.id)

    # 1. Messages embedded in the run metadata (plain dicts, cast to typed messages).
    result = finished_run.metadata.get("result")
    assert result is not None, "Run metadata missing 'result' key"
    assert_tool_call_response(
        cast_message_dict_to_messages(result["messages"]),
        llm_config=llm_config,
    )

    # 2. Messages listed through the runs API.
    run_messages = client.runs.messages.list(run_id=finished_run.id)
    assert_tool_call_response(run_messages, llm_config=llm_config)

    # 3. Messages persisted on the agent since the cursor captured above.
    persisted_messages = client.agents.messages.list(
        agent_id=agent_state.id,
        after=latest_before_request[0].id,
    )
    assert_tool_call_response(persisted_messages, from_db=True, llm_config=llm_config)
|
||||
|
||||
|
||||
class CallbackServer:
|
||||
@@ -1021,8 +1106,9 @@ def test_async_greeting_with_callback_url(
|
||||
# Validate job completed successfully
|
||||
result = run.metadata.get("result")
|
||||
assert result is not None, "Run metadata missing 'result' key"
|
||||
messages = result["messages"]
|
||||
assert_tool_response_dict_messages(messages)
|
||||
|
||||
messages = cast_message_dict_to_messages(result["messages"])
|
||||
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
|
||||
|
||||
# Validate callback was received
|
||||
assert server.wait_for_callback(timeout=15), "Callback was not received within timeout"
|
||||
@@ -1084,8 +1170,9 @@ def test_async_callback_failure_scenarios(
|
||||
# Validate job completed successfully
|
||||
result = run.metadata.get("result")
|
||||
assert result is not None, "Run metadata missing 'result' key"
|
||||
messages = result["messages"]
|
||||
assert_tool_response_dict_messages(messages)
|
||||
|
||||
messages = cast_message_dict_to_messages(result["messages"])
|
||||
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
|
||||
|
||||
# Job should be marked as completed even if callback failed
|
||||
assert run.status == "completed", f"Expected status 'completed', got {run.status}"
|
||||
|
||||
@@ -110,6 +110,7 @@ async def test_provider_trace_experimental_step(message, agent_state, default_us
|
||||
message_manager=MessageManager(),
|
||||
agent_manager=AgentManager(),
|
||||
block_manager=BlockManager(),
|
||||
job_manager=JobManager(),
|
||||
passage_manager=PassageManager(),
|
||||
step_manager=StepManager(),
|
||||
telemetry_manager=TelemetryManager(),
|
||||
@@ -134,6 +135,7 @@ async def test_provider_trace_experimental_step_stream(message, agent_state, def
|
||||
message_manager=MessageManager(),
|
||||
agent_manager=AgentManager(),
|
||||
block_manager=BlockManager(),
|
||||
job_manager=JobManager(),
|
||||
passage_manager=PassageManager(),
|
||||
step_manager=StepManager(),
|
||||
telemetry_manager=TelemetryManager(),
|
||||
@@ -189,6 +191,7 @@ async def test_noop_provider_trace(message, agent_state, default_user, event_loo
|
||||
message_manager=MessageManager(),
|
||||
agent_manager=AgentManager(),
|
||||
block_manager=BlockManager(),
|
||||
job_manager=JobManager(),
|
||||
passage_manager=PassageManager(),
|
||||
step_manager=StepManager(),
|
||||
telemetry_manager=NoopTelemetryManager(),
|
||||
|
||||
@@ -637,7 +637,7 @@ def test_many_blocks(client: LettaSDKClient):
|
||||
|
||||
|
||||
# cases: stream, async, token stream, sync
|
||||
@pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync"])
|
||||
@pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync", "async"])
|
||||
def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, message_create: str):
|
||||
"""Test that the include_return_message_types parameter works"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user