feat: integrate temporal into letta (#4766)
* feat: integrate temporal into letta * use fire and forget, set up cancellation and job status checking
This commit is contained in:
70
letta/agents/temporal_agent.py
Normal file
70
letta/agents/temporal_agent.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from temporalio import Client
|
||||
|
||||
from letta.agents.base_agent_v2 import BaseAgentV2
|
||||
from letta.agents.temporal.temporal_agent_workflow import TemporalAgentWorkflow
|
||||
from letta.agents.temporal.types import WorkflowInputParams
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
class TemporalAgent(BaseAgentV2):
|
||||
"""
|
||||
Execute the agent loop on temporal.
|
||||
"""
|
||||
|
||||
def __init__(self, agent_state: AgentState, actor: User):
|
||||
self.agent_state = agent_state
|
||||
self.actor = actor
|
||||
self.logger = get_logger(agent_state.id)
|
||||
|
||||
async def step(
|
||||
self,
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: str | None = None,
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop on temporal.
|
||||
"""
|
||||
if not run_id:
|
||||
raise ValueError("run_id is required")
|
||||
|
||||
client = await Client.connect(
|
||||
settings.temporal_endpoint,
|
||||
namespace=settings.temporal_namespace,
|
||||
api_key=settings.temporal_api_key,
|
||||
tls=True,
|
||||
)
|
||||
|
||||
await client.start_workflow(
|
||||
TemporalAgentWorkflow.run,
|
||||
"agent_loop_async",
|
||||
id=run_id,
|
||||
task_queue="agent_loop_async_task_queue",
|
||||
arg=(
|
||||
WorkflowInputParams(
|
||||
agent_state=self.agent_state,
|
||||
messages=input_messages,
|
||||
actor=self.actor,
|
||||
max_steps=max_steps,
|
||||
run_id=run_id,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
return LettaResponse(
|
||||
messages=[],
|
||||
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
|
||||
usage=LettaUsageStatistics(),
|
||||
)
|
||||
@@ -1489,6 +1489,15 @@ async def cancel_agent_run(
|
||||
|
||||
results = {}
|
||||
for run_id in run_ids:
|
||||
run = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
|
||||
if run.metadata.get("temporal") and settings.temporal_endpoint:
|
||||
client = await Client.connect(
|
||||
settings.temporal_endpoint,
|
||||
namespace=settings.temporal_namespace,
|
||||
api_key=settings.temporal_api_key,
|
||||
tls=True,
|
||||
)
|
||||
await client.cancel_workflow(run_id)
|
||||
success = await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run_id,
|
||||
new_status=JobStatus.cancelled,
|
||||
@@ -1645,6 +1654,7 @@ async def send_message_async(
|
||||
metadata={
|
||||
"job_type": "send_message_async",
|
||||
"agent_id": agent_id,
|
||||
"temporal": settings.temporal_endpoint != None,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
@@ -1655,6 +1665,17 @@ async def send_message_async(
|
||||
)
|
||||
run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor)
|
||||
|
||||
if settings.temporal_endpoint:
|
||||
temporal_agent = TemporalAgent(agent_state=agent_state, actor=actor)
|
||||
await temporal_agent.step(
|
||||
input_messages=request.messages,
|
||||
max_steps=request.max_steps,
|
||||
run_id=run.id,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
return run
|
||||
|
||||
# Create asyncio task for background processing (shielded to prevent cancellation)
|
||||
task = safe_create_shielded_task(
|
||||
_process_message_background(
|
||||
|
||||
@@ -96,7 +96,7 @@ def list_active_runs(
|
||||
|
||||
|
||||
@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
|
||||
def retrieve_run(
|
||||
async def retrieve_run(
|
||||
run_id: str,
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -104,10 +104,42 @@ def retrieve_run(
|
||||
"""
|
||||
Get the status of a run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=headers.actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(user_id=headers.actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
|
||||
|
||||
if job.metadata.get("temporal") and settings.temporal_endpoint:
|
||||
client = await Client.connect(
|
||||
settings.temporal_endpoint,
|
||||
namespace=settings.temporal_namespace,
|
||||
api_key=settings.temporal_api_key,
|
||||
tls=True,
|
||||
)
|
||||
handle = client.get_workflow_handle(workflow_id)
|
||||
|
||||
# Fetch the workflow description
|
||||
desc = await handle.describe()
|
||||
|
||||
# Map the status to our enum
|
||||
job_status = JobStatus.created
|
||||
if desc.status.name == "RUNNING":
|
||||
job_status = JobStatus.running
|
||||
elif desc.status.name == "COMPLETED":
|
||||
job_status = JobStatus.completed
|
||||
elif desc.status.name == "FAILED":
|
||||
job_status = JobStatus.failed
|
||||
elif desc.status.name == "CANCELED":
|
||||
job_status = JobStatus.canceled
|
||||
# elif desc.status.name == "TERMINATED":
|
||||
# job_status = JobStatus.terminated
|
||||
# elif desc.status.name == "TIMED_OUT":
|
||||
# job_status = JobStatus.timed_out
|
||||
# elif desc.status.name == "CONTINUED_AS_NEW":
|
||||
# return WorkflowStatus.CONTINUED_AS_NEW
|
||||
# else:
|
||||
# return WorkflowStatus.UNKNOWN
|
||||
job.status = job_status
|
||||
return Run.from_job(job)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
@@ -238,6 +238,10 @@ class Settings(BaseSettings):
|
||||
redis_host: Optional[str] = Field(default=None, description="Host for Redis instance")
|
||||
redis_port: Optional[int] = Field(default=6379, description="Port for Redis instance")
|
||||
|
||||
temporal_api_key: Optional[str] = Field(default=None, description="API key for Temporal instance")
|
||||
temporal_namespace: Optional[str] = Field(default=None, description="Namespace for Temporal instance")
|
||||
temporal_endpoint: Optional[str] = Field(default=None, description="Endpoint for Temporal instance")
|
||||
|
||||
plugin_register: Optional[str] = None
|
||||
|
||||
# multi agent settings
|
||||
|
||||
Reference in New Issue
Block a user