diff --git a/letta/agents/temporal_agent.py b/letta/agents/temporal_agent.py deleted file mode 100644 index c28d1c33..00000000 --- a/letta/agents/temporal_agent.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import AsyncGenerator - -from temporalio.client import Client - -from letta.agents.base_agent_v2 import BaseAgentV2 -from letta.agents.temporal.temporal_agent_workflow import TemporalAgentWorkflow -from letta.agents.temporal.types import WorkflowInputParams -from letta.constants import DEFAULT_MAX_STEPS -from letta.log import get_logger -from letta.schemas.agent import AgentState -from letta.schemas.enums import MessageStreamStatus -from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, MessageType -from letta.schemas.letta_response import LettaResponse -from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType -from letta.schemas.message import MessageCreate -from letta.schemas.usage import LettaUsageStatistics -from letta.schemas.user import User -from letta.settings import settings - - -class TemporalAgent(BaseAgentV2): - """ - Execute the agent loop on temporal. - """ - - def __init__(self, agent_state: AgentState, actor: User): - self.agent_state = agent_state - self.actor = actor - self.logger = get_logger(agent_state.id) - - async def step( - self, - input_messages: list[MessageCreate], - max_steps: int = DEFAULT_MAX_STEPS, - run_id: str | None = None, - use_assistant_message: bool = True, - include_return_message_types: list[MessageType] | None = None, - request_start_timestamp_ns: int | None = None, - ) -> LettaResponse: - """ - Execute the agent loop on temporal. 
- """ - if not run_id: - raise ValueError("run_id is required") - - client = await Client.connect( - settings.temporal_endpoint, - namespace=settings.temporal_namespace, - api_key=settings.temporal_api_key, - tls=settings.temporal_tls, # This should be false for local runs - ) - - workflow_input = WorkflowInputParams( - agent_state=self.agent_state, - messages=input_messages, - actor=self.actor, - max_steps=max_steps, - run_id=run_id, - ) - - await client.start_workflow( - TemporalAgentWorkflow.run, - workflow_input, - id=run_id, - task_queue=settings.temporal_task_queue, - ) - - return LettaResponse( - messages=[], - stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), - usage=LettaUsageStatistics(), - ) - - async def build_request( - self, - input_messages: list[MessageCreate], - ) -> dict: - raise NotImplementedError - - async def stream( - self, - input_messages: list[MessageCreate], - max_steps: int = DEFAULT_MAX_STEPS, - stream_tokens: bool = False, - run_id: str | None = None, - use_assistant_message: bool = True, - include_return_message_types: list[MessageType] | None = None, - request_start_timestamp_ns: int | None = None, - ) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]: - raise NotImplementedError diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index cc7e96a4..c4de97b5 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -11,11 +11,9 @@ from orjson import orjson from pydantic import BaseModel, Field from sqlalchemy.exc import IntegrityError, OperationalError from starlette.responses import Response, StreamingResponse -from temporalio.client import Client from letta.agents.agent_loop import AgentLoop from letta.agents.letta_agent_v2 import LettaAgentV2 -from letta.agents.temporal_agent import TemporalAgent from letta.constants import AGENT_ID_PATTERN, DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, 
DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client from letta.errors import ( @@ -58,6 +56,7 @@ from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator from letta.server.server import SyncServer +from letta.services.lettuce.lettuce_client import LettuceClient from letta.services.run_manager import RunManager from letta.settings import settings from letta.utils import safe_create_shielded_task, safe_create_task, truncate_file_visible_content @@ -1517,13 +1516,8 @@ async def cancel_agent_run( for run_id in run_ids: run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor) if run.metadata.get("lettuce") and settings.temporal_endpoint: - client = await Client.connect( - settings.temporal_endpoint, - namespace=settings.temporal_namespace, - api_key=settings.temporal_api_key, - tls=True, # This should be false for local runs - ) - await client.cancel_workflow(run_id) + lettuce_client = await LettuceClient.create() + await lettuce_client.cancel(run_id) success = await server.run_manager.update_run_by_id_async( run_id=run_id, update=RunUpdate(status=RunStatus.cancelled), @@ -1695,15 +1689,18 @@ async def send_message_async( agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"] ) if agent_state.multi_agent_group is None and agent_state.agent_type != AgentType.letta_v1_agent: - temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor) - await temporal_agent.step( + lettuce_client = await LettuceClient.create() + run_id_from_lettuce = await lettuce_client.step( + agent_state=agent_state, + actor=actor, input_messages=request.messages, max_steps=request.max_steps, run_id=run.id, 
use_assistant_message=request.use_assistant_message, include_return_message_types=request.include_return_message_types, ) - return run + if run_id_from_lettuce: + return run # Create asyncio task for background processing (shielded to prevent cancellation) task = safe_create_shielded_task( diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 3f406702..e864b0bb 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -3,7 +3,6 @@ from typing import Annotated, List, Literal, Optional from fastapi import APIRouter, Body, Depends, HTTPException, Query from pydantic import Field -from temporalio.client import Client from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client from letta.helpers.datetime_helpers import get_utc_time @@ -23,6 +22,7 @@ from letta.server.rest_api.streaming_response import ( cancellation_aware_stream_wrapper, ) from letta.server.server import SyncServer +from letta.services.lettuce.lettuce_client import LettuceClient from letta.services.run_manager import RunManager from letta.settings import settings @@ -136,26 +136,18 @@ async def retrieve_run( use_lettuce = run.metadata and run.metadata.get("lettuce") and settings.temporal_endpoint if use_lettuce and run.status not in [RunStatus.completed, RunStatus.failed, RunStatus.cancelled]: - client = await Client.connect( - settings.temporal_endpoint, - namespace=settings.temporal_namespace, - api_key=settings.temporal_api_key, - tls=settings.temporal_tls, # This should be false for local runs - ) - handle = client.get_workflow_handle(run_id) - - # Fetch the workflow description - desc = await handle.describe() + lettuce_client = await LettuceClient.create() + status = await lettuce_client.get_status(run_id) # Map the status to our enum run_status = RunStatus.created - if desc.status.name == "RUNNING": + if status == "RUNNING": run_status = RunStatus.running - elif desc.status.name == 
"COMPLETED": + elif status == "COMPLETED": run_status = RunStatus.completed - elif desc.status.name == "FAILED": + elif status == "FAILED": run_status = RunStatus.failed - elif desc.status.name == "CANCELED": + elif status == "CANCELLED": run_status = RunStatus.cancelled run.status = run_status return run diff --git a/letta/services/lettuce/lettuce_client_base.py b/letta/services/lettuce/lettuce_client_base.py new file mode 100644 index 00000000..45609b0a --- /dev/null +++ b/letta/services/lettuce/lettuce_client_base.py @@ -0,0 +1,86 @@ +from letta.constants import DEFAULT_MAX_STEPS +from letta.schemas.agent import AgentState +from letta.schemas.enums import MessageType +from letta.schemas.message import MessageCreate +from letta.schemas.user import User + + +class LettuceClient: + """Base class for LettuceClient.""" + + def __init__(self): + """Initialize the LettuceClient.""" + self.client: None = None + + @classmethod + async def create(cls) -> "LettuceClient": + """ + Asynchronously creates the client. + + Returns: + LettuceClient: The created LettuceClient instance. + """ + instance = cls() + return instance + + def get_client(self) -> None: + """ + Get the inner client. + + Returns: + None: The inner client. + """ + return self.client + + async def get_status(self, run_id: str) -> str | None: + """ + Get the status of a run. + + Args: + run_id (str): The ID of the run. + + Returns: + str | None: The status of the run or None if not available. + """ + return None + + async def cancel(self, run_id: str) -> str | None: + """ + Cancel a run. + + Args: + run_id (str): The ID of the run to cancel. + + Returns: + str | None: The ID of the canceled run or None if not available. 
+ """ + return None + + async def step( + self, + agent_state: AgentState, + actor: User, + input_messages: list[MessageCreate], + max_steps: int = DEFAULT_MAX_STEPS, + run_id: str | None = None, + use_assistant_message: bool = True, + include_return_message_types: list[MessageType] | None = None, + request_start_timestamp_ns: int | None = None, + ) -> str | None: + """ + Execute the agent loop on Lettuce service. + + Args: + agent_state (AgentState): The state of the agent. + actor (User): The actor. + input_messages (list[MessageCreate]): The input messages. + max_steps (int, optional): The maximum number of steps. Defaults to DEFAULT_MAX_STEPS. + run_id (str | None, optional): The ID of the run. Defaults to None. + use_assistant_message (bool, optional): Whether to use the assistant message. Defaults to True. + include_return_message_types (list[MessageType] | None, optional): The message types to include in the return. Defaults to None. + request_start_timestamp_ns (int | None, optional): The start timestamp of the request. Defaults to None. + + Returns: + str | None: The ID of the run or None if client is not available. + """ + return None