feat: integrate temporal into letta (#4766)

* feat: integrate temporal into letta

* use fire and forget, set up cancellation and job status checking
This commit is contained in:
cthomas
2025-09-18 13:45:01 -07:00
committed by Caren Thomas
parent 55b1e43e0c
commit 992f94da4b
4 changed files with 130 additions and 3 deletions

View File

@@ -0,0 +1,70 @@
from temporalio import Client
from letta.agents.base_agent_v2 import BaseAgentV2
from letta.agents.temporal.temporal_agent_workflow import TemporalAgentWorkflow
from letta.agents.temporal.types import WorkflowInputParams
from letta.constants import DEFAULT_MAX_STEPS
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.settings import settings
class TemporalAgent(BaseAgentV2):
    """
    Agent implementation that executes the agent loop on a Temporal cluster.

    `step` kicks off a workflow and returns immediately (fire-and-forget);
    callers observe progress/results through the run/job status APIs.
    """

    def __init__(self, agent_state: AgentState, actor: User):
        # Agent configuration/state handed to the workflow.
        self.agent_state = agent_state
        # User on whose behalf the workflow executes.
        self.actor = actor
        self.logger = get_logger(agent_state.id)

    async def step(
        self,
        input_messages: list[MessageCreate],
        max_steps: int = DEFAULT_MAX_STEPS,
        run_id: str | None = None,
        use_assistant_message: bool = True,
        include_return_message_types: list[MessageType] | None = None,
        request_start_timestamp_ns: int | None = None,
    ) -> LettaResponse:
        """
        Execute the agent loop on temporal (fire-and-forget).

        Args:
            input_messages: Messages to feed into the agent loop.
            max_steps: Upper bound on agent loop iterations.
            run_id: Required; reused as the Temporal workflow id so the run
                can later be cancelled or inspected by id.
            use_assistant_message: Accepted for interface compatibility with
                BaseAgentV2; not forwarded to the workflow here.
            include_return_message_types: Accepted for interface
                compatibility; not forwarded to the workflow here.
            request_start_timestamp_ns: Accepted for interface compatibility.

        Returns:
            An empty placeholder LettaResponse; real output is produced
            asynchronously by the workflow.

        Raises:
            ValueError: If `run_id` is not provided.
        """
        if not run_id:
            raise ValueError("run_id is required")

        client = await Client.connect(
            settings.temporal_endpoint,
            namespace=settings.temporal_namespace,
            api_key=settings.temporal_api_key,
            tls=True,
        )

        # BUG FIX: the original passed both a positional workflow-name string
        # ("agent_loop_async") and an `arg=` keyword; in the Temporal SDK the
        # second positional IS `arg`, so that call raises
        # "got multiple values for argument 'arg'". It also wrapped the input
        # in a tuple, where `arg` expects a single value (tuples go to
        # `args=[...]`). Pass the run method plus one input object instead.
        await client.start_workflow(
            TemporalAgentWorkflow.run,
            WorkflowInputParams(
                agent_state=self.agent_state,
                messages=input_messages,
                actor=self.actor,
                max_steps=max_steps,
                run_id=run_id,
            ),
            id=run_id,
            task_queue="agent_loop_async_task_queue",
        )

        # Fire-and-forget: return an empty response immediately; callers poll
        # the run status (see retrieve_run) for actual progress.
        return LettaResponse(
            messages=[],
            stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
            usage=LettaUsageStatistics(),
        )

View File

@@ -1489,6 +1489,15 @@ async def cancel_agent_run(
results = {}
for run_id in run_ids:
run = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
if run.metadata.get("temporal") and settings.temporal_endpoint:
client = await Client.connect(
settings.temporal_endpoint,
namespace=settings.temporal_namespace,
api_key=settings.temporal_api_key,
tls=True,
)
await client.cancel_workflow(run_id)
success = await server.job_manager.safe_update_job_status_async(
job_id=run_id,
new_status=JobStatus.cancelled,
@@ -1645,6 +1654,7 @@ async def send_message_async(
metadata={
"job_type": "send_message_async",
"agent_id": agent_id,
"temporal": settings.temporal_endpoint != None,
},
request_config=LettaRequestConfig(
use_assistant_message=request.use_assistant_message,
@@ -1655,6 +1665,17 @@ async def send_message_async(
)
run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor)
if settings.temporal_endpoint:
temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor)
await temporal_agent.step(
input_messages=request.messages,
max_steps=request.max_steps,
run_id=run.id,
use_assistant_message=request.use_assistant_message,
include_return_message_types=request.include_return_message_types,
)
return run
# Create asyncio task for background processing (shielded to prevent cancellation)
task = safe_create_shielded_task(
_process_message_background(

View File

@@ -96,7 +96,7 @@ def list_active_runs(
@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
async def retrieve_run(
    run_id: str,
    headers: HeaderParams = Depends(get_headers),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Get the status of a run.

    For runs marked as executing on Temporal, the live workflow status is
    fetched from the Temporal cluster and overlaid on the persisted job
    record before returning.

    Raises:
        HTTPException: 404 if no job exists with the given run id.
    """
    actor = await server.user_manager.get_actor_or_default_async(user_id=headers.actor_id)
    try:
        job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
        if job.metadata.get("temporal") and settings.temporal_endpoint:
            client = await Client.connect(
                settings.temporal_endpoint,
                namespace=settings.temporal_namespace,
                api_key=settings.temporal_api_key,
                tls=True,
            )
            # BUG FIX: the original referenced an undefined name
            # `workflow_id` (NameError at runtime); the workflow is started
            # with the run id as its workflow id, so look it up by run_id.
            handle = client.get_workflow_handle(run_id)
            # Fetch the live workflow description from Temporal.
            desc = await handle.describe()
            # Map the Temporal workflow status onto our JobStatus enum.
            # Statuses not listed here (e.g. TERMINATED, TIMED_OUT,
            # CONTINUED_AS_NEW) fall back to `created`.
            status_map = {
                "RUNNING": JobStatus.running,
                "COMPLETED": JobStatus.completed,
                "FAILED": JobStatus.failed,
                # NOTE(review): spelled `cancelled` to match the member used
                # by cancel_agent_run elsewhere in this module; the original's
                # `JobStatus.canceled` would raise AttributeError — confirm
                # against the JobStatus enum definition.
                "CANCELED": JobStatus.cancelled,
            }
            job.status = status_map.get(desc.status.name, JobStatus.created)
        return Run.from_job(job)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Run not found")

View File

@@ -238,6 +238,10 @@ class Settings(BaseSettings):
redis_host: Optional[str] = Field(default=None, description="Host for Redis instance")
redis_port: Optional[int] = Field(default=6379, description="Port for Redis instance")
temporal_api_key: Optional[str] = Field(default=None, description="API key for Temporal instance")
temporal_namespace: Optional[str] = Field(default=None, description="Namespace for Temporal instance")
temporal_endpoint: Optional[str] = Field(default=None, description="Endpoint for Temporal instance")
plugin_register: Optional[str] = None
# multi agent settings