diff --git a/letta/agents/temporal_agent.py b/letta/agents/temporal_agent.py
new file mode 100644
index 00000000..693cd28d
--- /dev/null
+++ b/letta/agents/temporal_agent.py
@@ -0,0 +1,89 @@
+from temporalio.client import Client
+
+from letta.agents.base_agent_v2 import BaseAgentV2
+from letta.agents.temporal.temporal_agent_workflow import TemporalAgentWorkflow
+from letta.agents.temporal.types import WorkflowInputParams
+from letta.constants import DEFAULT_MAX_STEPS
+from letta.log import get_logger
+from letta.schemas.agent import AgentState
+from letta.schemas.letta_message import MessageType
+from letta.schemas.letta_response import LettaResponse
+from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
+from letta.schemas.message import MessageCreate
+from letta.schemas.usage import LettaUsageStatistics
+from letta.schemas.user import User
+from letta.settings import settings
+
+
+class TemporalAgent(BaseAgentV2):
+    """
+    Execute the agent loop on temporal.
+
+    Rather than running the loop in-process, ``step`` starts a
+    ``TemporalAgentWorkflow`` on the configured Temporal service.
+    """
+
+    def __init__(self, agent_state: AgentState, actor: User):
+        self.agent_state = agent_state
+        self.actor = actor
+        self.logger = get_logger(agent_state.id)
+
+    async def step(
+        self,
+        input_messages: list[MessageCreate],
+        max_steps: int = DEFAULT_MAX_STEPS,
+        run_id: str | None = None,
+        use_assistant_message: bool = True,
+        include_return_message_types: list[MessageType] | None = None,
+        request_start_timestamp_ns: int | None = None,
+    ) -> LettaResponse:
+        """
+        Start the agent loop as a Temporal workflow and return immediately.
+
+        The run id is used as the Temporal workflow id.
+
+        Args:
+            input_messages: Messages handed to the workflow.
+            max_steps: Step limit forwarded to the workflow.
+            run_id: Id of the run tracking this request; required.
+            use_assistant_message: Not used by this implementation.
+            include_return_message_types: Not used by this implementation.
+            request_start_timestamp_ns: Not used by this implementation.
+
+        Returns:
+            An empty ``LettaResponse``; the workflow is started but not awaited.
+
+        Raises:
+            ValueError: If ``run_id`` is missing or empty.
+        """
+        if not run_id:
+            raise ValueError("run_id is required")
+
+        client = await Client.connect(
+            settings.temporal_endpoint,
+            namespace=settings.temporal_namespace,
+            api_key=settings.temporal_api_key,
+            tls=True,
+        )
+
+        # start_workflow takes the workflow input as its second positional
+        # parameter (``arg``); giving both a positional value and ``arg=``
+        # raises TypeError, so the input params are passed exactly once.
+        await client.start_workflow(
+            TemporalAgentWorkflow.run,
+            WorkflowInputParams(
+                agent_state=self.agent_state,
+                messages=input_messages,
+                actor=self.actor,
+                max_steps=max_steps,
+                run_id=run_id,
+            ),
+            id=run_id,
+            task_queue="agent_loop_async_task_queue",
+        )
+
+        return LettaResponse(
+            messages=[],
+            stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
+            usage=LettaUsageStatistics(),
+        )
diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index 01d3ab40..139ea770 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -1489,6 +1489,19 @@ async def cancel_agent_run(
 
     results = {}
     for run_id in run_ids:
+        run = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
+        if (run.metadata or {}).get("temporal") and settings.temporal_endpoint:
+            # Imported lazily so temporalio is only required when Temporal is configured.
+            from temporalio.client import Client
+
+            client = await Client.connect(
+                settings.temporal_endpoint,
+                namespace=settings.temporal_namespace,
+                api_key=settings.temporal_api_key,
+                tls=True,
+            )
+            # Workflows are started with the run id as their workflow id.
+            await client.get_workflow_handle(run_id).cancel()
         success = await server.job_manager.safe_update_job_status_async(
             job_id=run_id,
             new_status=JobStatus.cancelled,
@@ -1645,6 +1658,7 @@ async def send_message_async(
         metadata={
             "job_type": "send_message_async",
             "agent_id": agent_id,
+            "temporal": bool(settings.temporal_endpoint),
         },
         request_config=LettaRequestConfig(
             use_assistant_message=request.use_assistant_message,
@@ -1655,6 +1669,20 @@ async def send_message_async(
     )
     run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor)
 
+    if settings.temporal_endpoint:
+        # Imported lazily so temporalio is only required when Temporal is configured.
+        from letta.agents.temporal_agent import TemporalAgent
+
+        temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor)
+        await temporal_agent.step(
+            input_messages=request.messages,
+            max_steps=request.max_steps,
+            run_id=run.id,
+            use_assistant_message=request.use_assistant_message,
+            include_return_message_types=request.include_return_message_types,
+        )
+        return run
+
     # Create asyncio task for background processing (shielded to prevent cancellation)
     task = safe_create_shielded_task(
         _process_message_background(
diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py
index a622dba6..8d28ff31 100644
--- a/letta/server/rest_api/routers/v1/runs.py
+++ b/letta/server/rest_api/routers/v1/runs.py
@@ -96,7 +96,7 @@ def list_active_runs(
 
 
 @router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
-def retrieve_run(
+async def retrieve_run(
     run_id: str,
     headers: HeaderParams = Depends(get_headers),
     server: "SyncServer" = Depends(get_letta_server),
@@ -104,10 +104,38 @@ def retrieve_run(
     """
     Get the status of a run.
     """
-    actor = server.user_manager.get_user_or_default(user_id=headers.actor_id)
+    actor = await server.user_manager.get_actor_or_default_async(user_id=headers.actor_id)
 
     try:
-        job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
+        job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
+
+        if (job.metadata or {}).get("temporal") and settings.temporal_endpoint:
+            # Imported lazily so temporalio is only required when Temporal is configured.
+            from temporalio.client import Client
+
+            client = await Client.connect(
+                settings.temporal_endpoint,
+                namespace=settings.temporal_namespace,
+                api_key=settings.temporal_api_key,
+                tls=True,
+            )
+            # Workflows are started with the run id as their workflow id.
+            handle = client.get_workflow_handle(run_id)
+
+            # Fetch the workflow description
+            desc = await handle.describe()
+
+            # Map the Temporal status onto our enum.
+            # NOTE(review): TERMINATED, TIMED_OUT and CONTINUED_AS_NEW have no
+            # mapping yet and are reported as `created`, as is a missing status.
+            temporal_status_map = {
+                "RUNNING": JobStatus.running,
+                "COMPLETED": JobStatus.completed,
+                "FAILED": JobStatus.failed,
+                "CANCELED": JobStatus.cancelled,
+            }
+            status_name = desc.status.name if desc.status else None
+            job.status = temporal_status_map.get(status_name, JobStatus.created)
         return Run.from_job(job)
     except NoResultFound:
         raise HTTPException(status_code=404, detail="Run not found")
diff --git a/letta/settings.py b/letta/settings.py
index c6dd93fc..2baf4ad6 100644
--- a/letta/settings.py
+++ b/letta/settings.py
@@ -238,6 +238,10 @@ class Settings(BaseSettings):
     redis_host: Optional[str] = Field(default=None, description="Host for Redis instance")
     redis_port: Optional[int] = Field(default=6379, description="Port for Redis instance")
 
+    temporal_api_key: Optional[str] = Field(default=None, description="API key for Temporal instance")
+    temporal_namespace: Optional[str] = Field(default=None, description="Namespace for Temporal instance")
+    temporal_endpoint: Optional[str] = Field(default=None, description="Endpoint for Temporal instance")
+
     plugin_register: Optional[str] = None
 
     # multi agent settings