From fc980ff65421dcda3842185debf02c13a8857d31 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Mon, 9 Dec 2024 18:23:05 -0800
Subject: [PATCH] feat: add an async messages route (`/agent/{agent_id}/messages/async`) (#2206)

---
 letta/client/client.py                     |  41 ++++++-
 letta/server/rest_api/routers/v1/agents.py | 101 ++++++++++++++++++++-
 tests/test_client.py                       |  26 ++++++
 3 files changed, 166 insertions(+), 2 deletions(-)

diff --git a/letta/client/client.py b/letta/client/client.py
index 7e8ec304..97f36cd9 100644
--- a/letta/client/client.py
+++ b/letta/client/client.py
@@ -960,7 +960,6 @@ class RESTClient(AbstractClient):
         # TODO: figure out how to handle stream_steps and stream_tokens
 
         # When streaming steps is True, stream_tokens must be False
-        request = LettaRequest(messages=messages)
         if stream_tokens or stream_steps:
             from letta.client.streaming import _sse_post
 
@@ -985,6 +984,46 @@ class RESTClient(AbstractClient):
 
             return response
 
+    def send_message_async(
+        self,
+        message: str,
+        role: str,
+        agent_id: Optional[str] = None,
+        name: Optional[str] = None,
+    ) -> Job:
+        """
+        Send a message to an agent (async, returns a job)
+
+        Args:
+            message (str): Message to send
+            role (str): Role of the message
+            agent_id (str): ID of the agent
+            name (str): Name of the sender
+
+        Returns:
+            job (Job): Information about the async job
+
+        Raises:
+            ValueError: If `agent_id` is missing or the server does not answer with HTTP 200
+        """
+        # agent_id is interpolated into the URL below; None would be sent as the literal agent "None"
+        if agent_id is None:
+            raise ValueError("agent_id is required to send a message")
+
+        messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
+
+        request = LettaRequest(messages=messages)
+        response = requests.post(
+            f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/async",
+            json=request.model_dump(),
+            headers=self.headers,
+        )
+        if response.status_code != 200:
+            raise ValueError(f"Failed to send message: {response.text}")
+        response = Job(**response.json())
+
+        return response
+
     # humans / personas
 
     def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index e3922a28..e7f68dc9 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -3,7 +3,16 @@ import warnings
 from datetime import datetime
 from typing import List, Optional, Union
 
-from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
+from fastapi import (
+    APIRouter,
+    BackgroundTasks,
+    Body,
+    Depends,
+    Header,
+    HTTPException,
+    Query,
+    status,
+)
 from fastapi.responses import JSONResponse, StreamingResponse
 
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
@@ -14,6 +23,7 @@ from letta.schemas.block import ( # , BlockLabelUpdate, BlockLimitUpdate
     CreateBlock,
 )
 from letta.schemas.enums import MessageStreamStatus
+from letta.schemas.job import Job, JobStatus, JobUpdate
 from letta.schemas.letta_message import (
     LegacyLettaMessage,
     LettaMessage,
@@ -32,6 +42,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
 from letta.schemas.passage import Passage
 from letta.schemas.source import Source
 from letta.schemas.tool import Tool
+from letta.schemas.user import User
 from letta.server.rest_api.interface import StreamingServerInterface
 from letta.server.rest_api.utils import get_letta_server, sse_async_generator
 from letta.server.server import SyncServer
@@ -502,6 +513,94 @@ async def send_message_streaming(
     return result
 
 
+async def process_message_background(
+    job_id: str,
+    server: SyncServer,
+    actor: User,
+    agent_id: str,
+    user_id: str,
+    messages: list,
+    assistant_message_tool_name: str,
+    assistant_message_tool_kwarg: str,
+) -> None:
+    """Background task to process the message and update job status."""
+    try:
+        # TODO(matt) we should probably make this stream_steps and log each step as it progresses, so the job update GET can see the total steps so far + partial usage?
+        result = await send_message_to_agent(
+            server=server,
+            agent_id=agent_id,
+            user_id=user_id,
+            messages=messages,
+            stream_steps=False,  # NOTE(matt)
+            stream_tokens=False,
+            assistant_message_tool_name=assistant_message_tool_name,
+            assistant_message_tool_kwarg=assistant_message_tool_kwarg,
+        )
+
+        # Update job status to completed
+        job_update = JobUpdate(
+            status=JobStatus.completed,
+            completed_at=datetime.utcnow(),
+            metadata_={"result": result.model_dump()},  # Store the result in metadata
+        )
+        server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
+
+    except Exception as e:
+        # Update job status to failed
+        job_update = JobUpdate(
+            status=JobStatus.failed,
+            completed_at=datetime.utcnow(),
+            metadata_={"error": str(e)},
+        )
+        server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
+        raise
+
+
+@router.post(
+    "/{agent_id}/messages/async",
+    response_model=Job,
+    operation_id="create_agent_message_async",
+)
+async def send_message_async(
+    agent_id: str,
+    background_tasks: BackgroundTasks,
+    server: SyncServer = Depends(get_letta_server),
+    request: LettaRequest = Body(...),
+    user_id: Optional[str] = Header(None, alias="user_id"),
+):
+    """
+    Asynchronously process a user message and return a job ID.
+    The actual processing happens in the background, and the status can be checked using the job ID.
+    """
+    actor = server.get_user_or_default(user_id=user_id)
+
+    # Create a new job
+    job = Job(
+        user_id=actor.id,
+        status=JobStatus.created,
+        metadata_={
+            "job_type": "send_message_async",
+            "agent_id": agent_id,
+        },
+    )
+    job = server.job_manager.create_job(pydantic_job=job, actor=actor)
+
+    # Add the background task
+    background_tasks.add_task(
+        process_message_background,
+        job_id=job.id,
+        server=server,
+        actor=actor,
+        agent_id=agent_id,
+        user_id=actor.id,
+        messages=request.messages,
+        assistant_message_tool_name=request.assistant_message_tool_name,
+        assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
+    )
+
+    return job
+
+
 # TODO: move this into server.py?
 async def send_message_to_agent(
     server: SyncServer,
diff --git a/tests/test_client.py b/tests/test_client.py
index 866bd201..61a95f3a 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -14,6 +14,7 @@ from letta.orm import SandboxConfig, SandboxEnvironmentVariable
 from letta.schemas.agent import AgentState
 from letta.schemas.block import CreateBlock
 from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.job import JobStatus
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType
 from letta.utils import create_random_username
@@ -328,3 +329,28 @@ async def test_send_message_parallel(client: Union[LocalClient, RESTClient], age
 
     # Ensure both tasks completed
     assert len(responses) == len(messages), "Not all messages were processed"
+
+
+def test_send_message_async(client: Union[LocalClient, RESTClient], agent: AgentState):
+    """Test that we can send a message asynchronously"""
+
+    if not isinstance(client, RESTClient):
+        pytest.skip("send_message_async is only supported by the RESTClient")
+
+    print("Sending message asynchronously")
+    job = client.send_message_async(agent_id=agent.id, role="user", message="This is a test message, no need to respond.")
+    assert job.id is not None
+    assert job.status == JobStatus.created
+    print(f"Job created, job={job}, status={job.status}")
+
+    # Poll until the job reaches a terminal state; fail the test if it takes over 10 seconds
+    start_time = time.time()
+    while job.status not in (JobStatus.completed, JobStatus.failed):
+        time.sleep(1)
+        job = client.get_job(job_id=job.id)
+        print(f"Job status: {job.status}")
+        if time.time() - start_time > 10:
+            pytest.fail("Job took too long to complete")
+
+    print(f"Job completed in {time.time() - start_time} seconds, job={job}")
+    assert job.status == JobStatus.completed