feat: add new lettuce client (#5207)

* feat: add new lettuce client

* modest improvements

* fix comment
This commit is contained in:
cthomas
2025-10-07 11:28:52 -07:00
committed by Caren Thomas
parent 20ce885e07
commit 5b7e2b3f86
4 changed files with 102 additions and 117 deletions

View File

@@ -1,90 +0,0 @@
from typing import AsyncGenerator
from temporalio.client import Client
from letta.agents.base_agent_v2 import BaseAgentV2
from letta.agents.temporal.temporal_agent_workflow import TemporalAgentWorkflow
from letta.agents.temporal.types import WorkflowInputParams
from letta.constants import DEFAULT_MAX_STEPS
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, MessageType
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.settings import settings
class TemporalAgent(BaseAgentV2):
    """
    Agent whose loop is dispatched to a Temporal workflow rather than run in-process.
    """

    def __init__(self, agent_state: AgentState, actor: User):
        """
        Store the agent state and acting user, and create a logger keyed by the agent ID.
        """
        self.agent_state = agent_state
        self.actor = actor
        self.logger = get_logger(agent_state.id)

    async def step(
        self,
        input_messages: list[MessageCreate],
        max_steps: int = DEFAULT_MAX_STEPS,
        run_id: str | None = None,
        use_assistant_message: bool = True,
        include_return_message_types: list[MessageType] | None = None,
        request_start_timestamp_ns: int | None = None,
    ) -> LettaResponse:
        """
        Start the agent loop as a Temporal workflow.

        Connects to Temporal using the endpoint, namespace, API key and TLS
        values from ``settings``, starts ``TemporalAgentWorkflow.run`` with
        ``run_id`` as the workflow ID, and discards the returned handle. No
        workflow result is read here.

        Args:
            input_messages: Messages forwarded to the workflow.
            max_steps: Step limit forwarded to the workflow.
            run_id: Run ID; also used as the workflow ID. Required.
            use_assistant_message: Accepted but not used by this method.
            include_return_message_types: Accepted but not used by this method.
            request_start_timestamp_ns: Accepted but not used by this method.

        Returns:
            A LettaResponse with no messages, an ``end_turn`` stop reason and
            default usage statistics.

        Raises:
            ValueError: If ``run_id`` is None or empty.
        """
        if not run_id:
            raise ValueError("run_id is required")

        temporal_client = await Client.connect(
            settings.temporal_endpoint,
            namespace=settings.temporal_namespace,
            api_key=settings.temporal_api_key,
            tls=settings.temporal_tls,  # This should be false for local runs
        )

        # The handle returned by start_workflow is intentionally not kept.
        await temporal_client.start_workflow(
            TemporalAgentWorkflow.run,
            WorkflowInputParams(
                agent_state=self.agent_state,
                messages=input_messages,
                actor=self.actor,
                max_steps=max_steps,
                run_id=run_id,
            ),
            id=run_id,
            task_queue=settings.temporal_task_queue,
        )

        end_turn = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
        return LettaResponse(messages=[], stop_reason=end_turn, usage=LettaUsageStatistics())

    async def build_request(
        self,
        input_messages: list[MessageCreate],
    ) -> dict:
        """
        Not supported by this agent; always raises NotImplementedError.
        """
        raise NotImplementedError

    async def stream(
        self,
        input_messages: list[MessageCreate],
        max_steps: int = DEFAULT_MAX_STEPS,
        stream_tokens: bool = False,
        run_id: str | None = None,
        use_assistant_message: bool = True,
        include_return_message_types: list[MessageType] | None = None,
        request_start_timestamp_ns: int | None = None,
    ) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
        """
        Not supported by this agent; always raises NotImplementedError.
        """
        raise NotImplementedError

View File

@@ -11,11 +11,9 @@ from orjson import orjson
from pydantic import BaseModel, Field
from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.responses import Response, StreamingResponse
from temporalio.client import Client
from letta.agents.agent_loop import AgentLoop
from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.agents.temporal_agent import TemporalAgent
from letta.constants import AGENT_ID_PATTERN, DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import (
@@ -58,6 +56,7 @@ from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
from letta.server.server import SyncServer
from letta.services.lettuce.lettuce_client import LettuceClient
from letta.services.run_manager import RunManager
from letta.settings import settings
from letta.utils import safe_create_shielded_task, safe_create_task, truncate_file_visible_content
@@ -1517,13 +1516,8 @@ async def cancel_agent_run(
for run_id in run_ids:
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
if run.metadata.get("lettuce") and settings.temporal_endpoint:
client = await Client.connect(
settings.temporal_endpoint,
namespace=settings.temporal_namespace,
api_key=settings.temporal_api_key,
tls=True, # This should be false for local runs
)
await client.cancel_workflow(run_id)
lettuce_client = await LettuceClient.create()
await lettuce_client.cancel(run_id)
success = await server.run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.cancelled),
@@ -1695,15 +1689,18 @@ async def send_message_async(
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
)
if agent_state.multi_agent_group is None and agent_state.agent_type != AgentType.letta_v1_agent:
temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor)
await temporal_agent.step(
lettuce_client = LettuceClient.create()
run_id_from_lettuce = await lettuce_client.step(
agent_state=agent_state,
actor=actor,
input_messages=request.messages,
max_steps=request.max_steps,
run_id=run.id,
use_assistant_message=request.use_assistant_message,
include_return_message_types=request.include_return_message_types,
)
return run
if run_id_from_lettuce:
return run
# Create asyncio task for background processing (shielded to prevent cancellation)
task = safe_create_shielded_task(

View File

@@ -3,7 +3,6 @@ from typing import Annotated, List, Literal, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import Field
from temporalio.client import Client
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.helpers.datetime_helpers import get_utc_time
@@ -23,6 +22,7 @@ from letta.server.rest_api.streaming_response import (
cancellation_aware_stream_wrapper,
)
from letta.server.server import SyncServer
from letta.services.lettuce.lettuce_client import LettuceClient
from letta.services.run_manager import RunManager
from letta.settings import settings
@@ -136,26 +136,18 @@ async def retrieve_run(
use_lettuce = run.metadata and run.metadata.get("lettuce") and settings.temporal_endpoint
if use_lettuce and run.status not in [RunStatus.completed, RunStatus.failed, RunStatus.cancelled]:
client = await Client.connect(
settings.temporal_endpoint,
namespace=settings.temporal_namespace,
api_key=settings.temporal_api_key,
tls=settings.temporal_tls, # This should be false for local runs
)
handle = client.get_workflow_handle(run_id)
# Fetch the workflow description
desc = await handle.describe()
lettuce_client = await LettuceClient.create()
status = await lettuce_client.get_status()
# Map the status to our enum
run_status = RunStatus.created
if desc.status.name == "RUNNING":
if status == "RUNNING":
run_status = RunStatus.running
elif desc.status.name == "COMPLETED":
elif status == "COMPLETED":
run_status = RunStatus.completed
elif desc.status.name == "FAILED":
elif status == "FAILED":
run_status = RunStatus.failed
elif desc.status.name == "CANCELED":
elif status == "CANCELLED":
run_status = RunStatus.cancelled
run.status = run_status
return run

View File

@@ -0,0 +1,86 @@
from letta.constants import DEFAULT_MAX_STEPS
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageType
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
class LettuceClient:
    """Base Lettuce client whose operations are all no-ops that return None."""

    def __init__(self):
        """Create the client with no inner client attached."""
        self.client: None = None

    @classmethod
    async def create(cls) -> "LettuceClient":
        """
        Build a client instance asynchronously.

        Returns:
            LettuceClient: A newly constructed instance of ``cls``.
        """
        return cls()

    def get_client(self) -> None:
        """
        Return the inner client.

        Returns:
            None: The value stored in ``self.client``.
        """
        return self.client

    async def get_status(self, run_id: str) -> str | None:
        """
        Look up the status of a run.

        Args:
            run_id (str): The ID of the run.

        Returns:
            str | None: The run status, or None if not available. This
            implementation always returns None.
        """
        return None

    async def cancel(self, run_id: str) -> str | None:
        """
        Request cancellation of a run.

        Args:
            run_id (str): The ID of the run to cancel.

        Returns:
            str | None: The ID of the canceled run, or None if not available.
            This implementation always returns None.
        """
        return None

    async def step(
        self,
        agent_state: AgentState,
        actor: User,
        input_messages: list[MessageCreate],
        max_steps: int = DEFAULT_MAX_STEPS,
        run_id: str | None = None,
        use_assistant_message: bool = True,
        include_return_message_types: list[MessageType] | None = None,
        request_start_timestamp_ns: int | None = None,
    ) -> str | None:
        """
        Run the agent loop on the Lettuce service.

        Args:
            agent_state (AgentState): The state of the agent.
            actor (User): The actor.
            input_messages (list[MessageCreate]): The input messages.
            max_steps (int, optional): The maximum number of steps. Defaults to DEFAULT_MAX_STEPS.
            run_id (str | None, optional): The ID of the run. Defaults to None.
            use_assistant_message (bool, optional): Whether to use the assistant message. Defaults to True.
            include_return_message_types (list[MessageType] | None, optional): The message types to include
                in the return. Defaults to None.
            request_start_timestamp_ns (int | None, optional): The start timestamp of the request.
                Defaults to None.

        Returns:
            str | None: The ID of the run, or None if the client is not
            available. This implementation always returns None.
        """
        return None