feat: add new lettuce client (#5207)
* feat: add new lettuce client * modest improvements * fix comment
This commit is contained in:
@@ -1,90 +0,0 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from temporalio.client import Client
|
||||
|
||||
from letta.agents.base_agent_v2 import BaseAgentV2
|
||||
from letta.agents.temporal.temporal_agent_workflow import TemporalAgentWorkflow
|
||||
from letta.agents.temporal.types import WorkflowInputParams
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, MessageType
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
class TemporalAgent(BaseAgentV2):
|
||||
"""
|
||||
Execute the agent loop on temporal.
|
||||
"""
|
||||
|
||||
def __init__(self, agent_state: AgentState, actor: User):
|
||||
self.agent_state = agent_state
|
||||
self.actor = actor
|
||||
self.logger = get_logger(agent_state.id)
|
||||
|
||||
async def step(
|
||||
self,
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: str | None = None,
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop on temporal.
|
||||
"""
|
||||
if not run_id:
|
||||
raise ValueError("run_id is required")
|
||||
|
||||
client = await Client.connect(
|
||||
settings.temporal_endpoint,
|
||||
namespace=settings.temporal_namespace,
|
||||
api_key=settings.temporal_api_key,
|
||||
tls=settings.temporal_tls, # This should be false for local runs
|
||||
)
|
||||
|
||||
workflow_input = WorkflowInputParams(
|
||||
agent_state=self.agent_state,
|
||||
messages=input_messages,
|
||||
actor=self.actor,
|
||||
max_steps=max_steps,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
await client.start_workflow(
|
||||
TemporalAgentWorkflow.run,
|
||||
workflow_input,
|
||||
id=run_id,
|
||||
task_queue=settings.temporal_task_queue,
|
||||
)
|
||||
|
||||
return LettaResponse(
|
||||
messages=[],
|
||||
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
|
||||
usage=LettaUsageStatistics(),
|
||||
)
|
||||
|
||||
async def build_request(
|
||||
self,
|
||||
input_messages: list[MessageCreate],
|
||||
) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
stream_tokens: bool = False,
|
||||
run_id: str | None = None,
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
|
||||
raise NotImplementedError
|
||||
@@ -11,11 +11,9 @@ from orjson import orjson
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
from temporalio.client import Client
|
||||
|
||||
from letta.agents.agent_loop import AgentLoop
|
||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.agents.temporal_agent import TemporalAgent
|
||||
from letta.constants import AGENT_ID_PATTERN, DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.errors import (
|
||||
@@ -58,6 +56,7 @@ from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.lettuce.lettuce_client import LettuceClient
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_shielded_task, safe_create_task, truncate_file_visible_content
|
||||
@@ -1517,13 +1516,8 @@ async def cancel_agent_run(
|
||||
for run_id in run_ids:
|
||||
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
|
||||
if run.metadata.get("lettuce") and settings.temporal_endpoint:
|
||||
client = await Client.connect(
|
||||
settings.temporal_endpoint,
|
||||
namespace=settings.temporal_namespace,
|
||||
api_key=settings.temporal_api_key,
|
||||
tls=True, # This should be false for local runs
|
||||
)
|
||||
await client.cancel_workflow(run_id)
|
||||
lettuce_client = await LettuceClient.create()
|
||||
await lettuce_client.cancel(run_id)
|
||||
success = await server.run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.cancelled),
|
||||
@@ -1695,15 +1689,18 @@ async def send_message_async(
|
||||
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
|
||||
)
|
||||
if agent_state.multi_agent_group is None and agent_state.agent_type != AgentType.letta_v1_agent:
|
||||
temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor)
|
||||
await temporal_agent.step(
|
||||
lettuce_client = LettuceClient.create()
|
||||
run_id_from_lettuce = await lettuce_client.step(
|
||||
agent_state=agent_state,
|
||||
actor=actor,
|
||||
input_messages=request.messages,
|
||||
max_steps=request.max_steps,
|
||||
run_id=run.id,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
return run
|
||||
if run_id_from_lettuce:
|
||||
return run
|
||||
|
||||
# Create asyncio task for background processing (shielded to prevent cancellation)
|
||||
task = safe_create_shielded_task(
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Annotated, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from pydantic import Field
|
||||
from temporalio.client import Client
|
||||
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
@@ -23,6 +22,7 @@ from letta.server.rest_api.streaming_response import (
|
||||
cancellation_aware_stream_wrapper,
|
||||
)
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.lettuce.lettuce_client import LettuceClient
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.settings import settings
|
||||
|
||||
@@ -136,26 +136,18 @@ async def retrieve_run(
|
||||
|
||||
use_lettuce = run.metadata and run.metadata.get("lettuce") and settings.temporal_endpoint
|
||||
if use_lettuce and run.status not in [RunStatus.completed, RunStatus.failed, RunStatus.cancelled]:
|
||||
client = await Client.connect(
|
||||
settings.temporal_endpoint,
|
||||
namespace=settings.temporal_namespace,
|
||||
api_key=settings.temporal_api_key,
|
||||
tls=settings.temporal_tls, # This should be false for local runs
|
||||
)
|
||||
handle = client.get_workflow_handle(run_id)
|
||||
|
||||
# Fetch the workflow description
|
||||
desc = await handle.describe()
|
||||
lettuce_client = await LettuceClient.create()
|
||||
status = await lettuce_client.get_status()
|
||||
|
||||
# Map the status to our enum
|
||||
run_status = RunStatus.created
|
||||
if desc.status.name == "RUNNING":
|
||||
if status == "RUNNING":
|
||||
run_status = RunStatus.running
|
||||
elif desc.status.name == "COMPLETED":
|
||||
elif status == "COMPLETED":
|
||||
run_status = RunStatus.completed
|
||||
elif desc.status.name == "FAILED":
|
||||
elif status == "FAILED":
|
||||
run_status = RunStatus.failed
|
||||
elif desc.status.name == "CANCELED":
|
||||
elif status == "CANCELLED":
|
||||
run_status = RunStatus.cancelled
|
||||
run.status = run_status
|
||||
return run
|
||||
|
||||
86
letta/services/lettuce/lettuce_client_base.py
Normal file
86
letta/services/lettuce/lettuce_client_base.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.user import User
|
||||
|
||||
|
||||
class LettuceClient:
|
||||
"""Base class for LettuceClient."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the LettuceClient."""
|
||||
self.client: None = None
|
||||
|
||||
@classmethod
|
||||
async def create(cls) -> "LettuceClient":
|
||||
"""
|
||||
Asynchronously creates the client.
|
||||
|
||||
Returns:
|
||||
LettuceClient: The created LettuceClient instance.
|
||||
"""
|
||||
instance = cls()
|
||||
return instance
|
||||
|
||||
def get_client(self) -> None:
|
||||
"""
|
||||
Get the inner client.
|
||||
|
||||
Returns:
|
||||
None: The inner client.
|
||||
"""
|
||||
return self.client
|
||||
|
||||
async def get_status(self, run_id: str) -> str | None:
|
||||
"""
|
||||
Get the status of a run.
|
||||
|
||||
Args:
|
||||
run_id (str): The ID of the run.
|
||||
|
||||
Returns:
|
||||
str | None: The status of the run or None if not available.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def cancel(self, run_id: str) -> str | None:
|
||||
"""
|
||||
Cancel a run.
|
||||
|
||||
Args:
|
||||
run_id (str): The ID of the run to cancel.
|
||||
|
||||
Returns:
|
||||
str | None: The ID of the canceled run or None if not available.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def step(
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: str | None = None,
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Execute the agent loop on Lettuce service.
|
||||
|
||||
Args:
|
||||
agent_state (AgentState): The state of the agent.
|
||||
actor (User): The actor.
|
||||
input_messages (list[MessageCreate]): The input messages.
|
||||
max_steps (int, optional): The maximum number of steps. Defaults to DEFAULT_MAX_STEPS.
|
||||
run_id (str | None, optional): The ID of the run. Defaults to None.
|
||||
use_assistant_message (bool, optional): Whether to use the assistant message. Defaults to True.
|
||||
include_return_message_types (list[MessageType] | None, optional): The message types to include in the return. Defaults to None.
|
||||
request_start_timestamp_ns (int | None, optional): The start timestamp of the request. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str | None: The ID of the run or None if client is not available.
|
||||
"""
|
||||
return None
|
||||
Reference in New Issue
Block a user