diff --git a/alembic/versions/567e9fe06270_create_new_runs_table_and_remove_legacy_.py b/alembic/versions/567e9fe06270_create_new_runs_table_and_remove_legacy_.py new file mode 100644 index 00000000..aa4cf907 --- /dev/null +++ b/alembic/versions/567e9fe06270_create_new_runs_table_and_remove_legacy_.py @@ -0,0 +1,128 @@ +"""create new runs table and remove legacy tables + +Revision ID: 567e9fe06270 +Revises: 3d2e9fb40a3c +Create Date: 2025-09-22 15:22:28.651178 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "567e9fe06270" +down_revision: Union[str, None] = "3d2e9fb40a3c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "runs", + sa.Column("id", sa.String(), nullable=False), + sa.Column("status", sa.String(), nullable=False), + sa.Column("completed_at", sa.DateTime(), nullable=True), + sa.Column("stop_reason", sa.String(), nullable=True), + sa.Column("background", sa.Boolean(), nullable=True), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.Column("request_config", sa.JSON(), nullable=True), + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("callback_url", sa.String(), nullable=True), + sa.Column("callback_sent_at", sa.DateTime(), nullable=True), + sa.Column("callback_status_code", sa.Integer(), nullable=True), + sa.Column("callback_error", sa.String(), nullable=True), + sa.Column("ttft_ns", sa.BigInteger(), nullable=True), + sa.Column("total_duration_ns", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", 
sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("project_id", sa.String(), nullable=True), + sa.Column("base_template_id", sa.String(), nullable=True), + sa.Column("template_id", sa.String(), nullable=True), + sa.Column("deployment_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["agent_id"], + ["agents.id"], + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_runs_agent_id", "runs", ["agent_id"], unique=False) + op.create_index("ix_runs_created_at", "runs", ["created_at", "id"], unique=False) + op.create_index("ix_runs_organization_id", "runs", ["organization_id"], unique=False) + op.drop_index(op.f("ix_agents_runs_agent_id_run_id"), table_name="agents_runs") + op.drop_index(op.f("ix_agents_runs_run_id_agent_id"), table_name="agents_runs") + op.drop_table("agents_runs") + op.drop_table("job_messages") + op.add_column("messages", sa.Column("run_id", sa.String(), nullable=True)) + op.create_foreign_key("fk_messages_run_id", "messages", "runs", ["run_id"], ["id"], ondelete="SET NULL") + op.add_column("step_metrics", sa.Column("run_id", sa.String(), nullable=True)) + op.drop_constraint(op.f("step_metrics_job_id_fkey"), "step_metrics", type_="foreignkey") + op.create_foreign_key("fk_step_metrics_run_id", "step_metrics", "runs", ["run_id"], ["id"], ondelete="SET NULL") + op.drop_column("step_metrics", "job_id") + op.add_column("steps", sa.Column("run_id", sa.String(), nullable=True)) + op.drop_index(op.f("ix_steps_job_id"), table_name="steps") + op.create_index("ix_steps_run_id", "steps", ["run_id"], unique=False) + op.drop_constraint(op.f("fk_steps_job_id"), "steps", type_="foreignkey") + op.create_foreign_key("fk_steps_run_id", "steps", "runs", ["run_id"], 
["id"], ondelete="SET NULL") + op.drop_column("steps", "job_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("steps", sa.Column("job_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint("fk_steps_run_id", "steps", type_="foreignkey") + op.create_foreign_key(op.f("fk_steps_job_id"), "steps", "jobs", ["job_id"], ["id"], ondelete="SET NULL") + op.drop_index("ix_steps_run_id", table_name="steps") + op.create_index(op.f("ix_steps_job_id"), "steps", ["job_id"], unique=False) + op.drop_column("steps", "run_id") + op.add_column("step_metrics", sa.Column("job_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint("fk_step_metrics_run_id", "step_metrics", type_="foreignkey") + op.create_foreign_key(op.f("step_metrics_job_id_fkey"), "step_metrics", "jobs", ["job_id"], ["id"], ondelete="SET NULL") + op.drop_column("step_metrics", "run_id") + op.drop_constraint("fk_messages_run_id", "messages", type_="foreignkey") + op.drop_column("messages", "run_id") + op.create_table( + "job_messages", + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("job_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("message_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False), + sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], name=op.f("fk_job_messages_job_id"), 
ondelete="CASCADE"), + sa.ForeignKeyConstraint(["message_id"], ["messages.id"], name=op.f("fk_job_messages_message_id"), ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id", name=op.f("pk_job_messages")), + sa.UniqueConstraint( + "job_id", "message_id", name=op.f("unique_job_message"), postgresql_include=[], postgresql_nulls_not_distinct=False + ), + ) + op.create_table( + "agents_runs", + sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("run_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], name=op.f("agents_runs_agent_id_fkey")), + sa.ForeignKeyConstraint(["run_id"], ["jobs.id"], name=op.f("agents_runs_run_id_fkey")), + sa.PrimaryKeyConstraint("agent_id", "run_id", name=op.f("unique_agent_run")), + ) + op.create_index(op.f("ix_agents_runs_run_id_agent_id"), "agents_runs", ["run_id", "agent_id"], unique=False) + op.create_index(op.f("ix_agents_runs_agent_id_run_id"), "agents_runs", ["agent_id", "run_id"], unique=False) + op.drop_index("ix_runs_organization_id", table_name="runs") + op.drop_index("ix_runs_created_at", table_name="runs") + op.drop_index("ix_runs_agent_id", table_name="runs") + op.drop_table("runs") + # ### end Alembic commands ### diff --git a/fern/openapi.json b/fern/openapi.json index aa200772..f5a5a9e7 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -6071,7 +6071,7 @@ "post": { "tags": ["agents"], "summary": "Send Message Async", - "description": "Asynchronously process a user message and return a run object.\nThe actual processing happens in the background, and the status can be checked using the run ID.\n\nThis is \"asynchronous\" in the sense that it's a background job and explicitly must be fetched by the run ID.\nThis is more like `send_message_job`", + "description": "Asynchronously process a user message and return a run object.\nThe actual processing happens in the background, and the status can be checked using the run ID.\n\nThis 
is \"asynchronous\" in the sense that it's a background run and explicitly must be fetched by the run ID.", "operationId": "create_agent_message_async", "parameters": [ { @@ -29299,6 +29299,18 @@ "title": "Step Id", "description": "The id of the step that this message was created in." }, + "run_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Run Id", + "description": "The id of the run that this message was created in." + }, "otid": { "anyOf": [ { @@ -31026,53 +31038,23 @@ }, "Run": { "properties": { - "created_by_id": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Created By Id", - "description": "The id of the user that made this object." + "id": { + "type": "string", + "pattern": "^(job|run)-[a-fA-F0-9]{8}", + "title": "Id", + "description": "The human-friendly ID of the Run", + "examples": ["run-123e4567-e89b-12d3-a456-426614174000"] }, - "last_updated_by_id": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Last Updated By Id", - "description": "The id of the user that made this object." + "status": { + "$ref": "#/components/schemas/RunStatus", + "description": "The current status of the run.", + "default": "created" }, "created_at": { "type": "string", "format": "date-time", "title": "Created At", - "description": "The unix timestamp of when the job was created." - }, - "updated_at": { - "anyOf": [ - { - "type": "string", - "format": "date-time" - }, - { - "type": "null" - } - ], - "title": "Updated At", - "description": "The timestamp when the object was last updated." - }, - "status": { - "$ref": "#/components/schemas/JobStatus", - "description": "The status of the job.", - "default": "created" + "description": "The timestamp when the run was created." }, "completed_at": { "anyOf": [ @@ -31085,18 +31067,24 @@ } ], "title": "Completed At", - "description": "The unix timestamp of when the job was completed." 
+ "description": "The timestamp when the run was completed." }, - "stop_reason": { + "agent_id": { + "type": "string", + "title": "Agent Id", + "description": "The unique identifier of the agent associated with the run." + }, + "background": { "anyOf": [ { - "$ref": "#/components/schemas/StopReasonType" + "type": "boolean" }, { "type": "null" } ], - "description": "The reason why the run was stopped." + "title": "Background", + "description": "Whether the run was created in background mode." }, "metadata": { "anyOf": [ @@ -31109,35 +31097,29 @@ } ], "title": "Metadata", - "description": "The metadata of the job." + "description": "Additional metadata for the run." }, - "job_type": { - "$ref": "#/components/schemas/JobType", - "default": "run" - }, - "background": { + "request_config": { "anyOf": [ { - "type": "boolean" + "$ref": "#/components/schemas/LettaRequestConfig" }, { "type": "null" } ], - "title": "Background", - "description": "Whether the job was created in background mode." + "description": "The request configuration for the run." }, - "agent_id": { + "stop_reason": { "anyOf": [ { - "type": "string" + "$ref": "#/components/schemas/StopReasonType" }, { "type": "null" } ], - "title": "Agent Id", - "description": "The agent associated with this job/run." + "description": "The reason why the run was stopped." }, "callback_url": { "anyOf": [ @@ -31149,7 +31131,7 @@ } ], "title": "Callback Url", - "description": "If set, POST to this URL when the job completes." + "description": "If set, POST to this URL when the run completes." 
}, "callback_sent_at": { "anyOf": [ @@ -31211,30 +31193,19 @@ ], "title": "Total Duration Ns", "description": "Total run duration in nanoseconds" - }, - "id": { - "type": "string", - "pattern": "^(job|run)-[a-fA-F0-9]{8}", - "title": "Id", - "description": "The human-friendly ID of the Run", - "examples": ["run-123e4567-e89b-12d3-a456-426614174000"] - }, - "request_config": { - "anyOf": [ - { - "$ref": "#/components/schemas/LettaRequestConfig" - }, - { - "type": "null" - } - ], - "description": "The request configuration for the run." } }, "additionalProperties": false, "type": "object", + "required": ["agent_id"], "title": "Run", - "description": "Representation of a run, which is a job with a 'run' prefix in its ID.\nInherits all fields and behavior from Job except for the ID prefix.\n\nParameters:\n id (str): The unique identifier of the run (prefixed with 'run-').\n status (JobStatus): The status of the run.\n created_at (datetime): The unix timestamp of when the run was created.\n completed_at (datetime): The unix timestamp of when the run was completed.\n user_id (str): The unique identifier of the user associated with the run." + "description": "Representation of a run - a conversation or processing session for an agent.\nRuns track when agents process messages and maintain the relationship between agents, steps, and messages.\n\nParameters:\n id (str): The unique identifier of the run (prefixed with 'run-').\n status (JobStatus): The current status of the run.\n created_at (datetime): The timestamp when the run was created.\n completed_at (datetime): The timestamp when the run was completed.\n agent_id (str): The unique identifier of the agent associated with the run.\n stop_reason (StopReasonType): The reason why the run was stopped.\n background (bool): Whether the run was created in background mode.\n metadata (dict): Additional metadata for the run.\n request_config (LettaRequestConfig): The request configuration for the run." 
+ }, + "RunStatus": { + "type": "string", + "enum": ["created", "running", "completed", "failed", "cancelled"], + "title": "RunStatus", + "description": "Status of the run." }, "SSEServerConfig": { "properties": { @@ -32131,7 +32102,7 @@ "title": "Provider Id", "description": "The unique identifier of the provider that was configured for this step" }, - "job_id": { + "run_id": { "anyOf": [ { "type": "string" @@ -32140,8 +32111,8 @@ "type": "null" } ], - "title": "Job Id", - "description": "The unique identifier of the job that this step belongs to. Only included for async calls." + "title": "Run Id", + "description": "The unique identifier of the run that this step belongs to. Only included for async calls." }, "agent_id": { "anyOf": [ @@ -32405,7 +32376,7 @@ "title": "Provider Id", "description": "The unique identifier of the provider." }, - "job_id": { + "run_id": { "anyOf": [ { "type": "string" @@ -32414,8 +32385,8 @@ "type": "null" } ], - "title": "Job Id", - "description": "The unique identifier of the job." + "title": "Run Id", + "description": "The unique identifier of the run." 
}, "agent_id": { "anyOf": [ diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 228de580..4fb0c105 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -300,7 +300,7 @@ class LettaAgent(BaseAgent): context_window_limit=agent_state.llm_config.context_window, usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), provider_id=None, - job_id=self.current_run_id if self.current_run_id else None, + run_id=self.current_run_id if self.current_run_id else None, step_id=step_id, project_id=agent_state.project_id, status=StepStatus.PENDING, @@ -644,7 +644,7 @@ class LettaAgent(BaseAgent): context_window_limit=agent_state.llm_config.context_window, usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), provider_id=None, - job_id=run_id if run_id else self.current_run_id, + run_id=run_id if run_id else self.current_run_id, step_id=step_id, project_id=agent_state.project_id, status=StepStatus.PENDING, @@ -768,7 +768,7 @@ class LettaAgent(BaseAgent): step_id=step_id, agent_state=agent_state, step_metrics=step_metrics, - job_id=run_id if run_id else self.current_run_id, + run_id=run_id if run_id else self.current_run_id, ) except Exception as e: @@ -989,7 +989,7 @@ class LettaAgent(BaseAgent): context_window_limit=agent_state.llm_config.context_window, usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), provider_id=None, - job_id=self.current_run_id if self.current_run_id else None, + run_id=self.current_run_id if self.current_run_id else None, step_id=step_id, project_id=agent_state.project_id, status=StepStatus.PENDING, @@ -1676,6 +1676,7 @@ class LettaAgent(BaseAgent): reasoning_content=None, pre_computed_assistant_message_id=None, step_id=step_id, + run_id=self.current_run_id, is_approval_response=True, ) messages_to_persist = (initial_messages or []) + tool_call_messages @@ -1786,6 +1787,7 @@ class LettaAgent(BaseAgent): reasoning_content=reasoning_content, 
pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, + run_id=self.current_run_id, is_approval_response=is_approval or is_denial, ) messages_to_persist = (initial_messages or []) + tool_call_messages @@ -1794,13 +1796,6 @@ class LettaAgent(BaseAgent): messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id ) - if run_id: - await self.job_manager.add_messages_to_job_async( - job_id=run_id, - message_ids=[m.id for m in persisted_messages if m.role != "user"], - actor=self.actor, - ) - return persisted_messages, continue_stepping, stop_reason def _decide_continuation( diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index fada233f..380c9332 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -1,7 +1,7 @@ import asyncio import uuid from datetime import datetime -from typing import AsyncGenerator, Tuple +from typing import AsyncGenerator, Optional, Tuple from opentelemetry.trace import Span @@ -31,7 +31,7 @@ from letta.log import get_logger from letta.otel.tracing import log_event, trace_method, tracer from letta.prompts.prompt_generator import PromptGenerator from letta.schemas.agent import AgentState, UpdateAgent -from letta.schemas.enums import AgentType, JobStatus, MessageStreamStatus, StepStatus +from letta.schemas.enums import AgentType, MessageStreamStatus, RunStatus, StepStatus from letta.schemas.letta_message import LettaMessage, MessageType from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse @@ -51,9 +51,9 @@ from letta.services.agent_manager import AgentManager from letta.services.archive_manager import ArchiveManager from letta.services.block_manager import BlockManager from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema -from 
letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.run_manager import RunManager from letta.services.step_manager import StepManager from letta.services.summarizer.enums import SummarizationMode from letta.services.summarizer.summarizer import Summarizer @@ -93,7 +93,7 @@ class LettaAgentV2(BaseAgentV2): self.agent_manager = AgentManager() self.archive_manager = ArchiveManager() self.block_manager = BlockManager() - self.job_manager = JobManager() + self.run_manager = RunManager() self.message_manager = MessageManager() self.passage_manager = PassageManager() self.step_manager = StepManager() @@ -145,9 +145,11 @@ class LettaAgentV2(BaseAgentV2): input_messages, self.agent_state, self.message_manager, self.actor ) response = self._step( + run_id=None, messages=in_context_messages + input_messages_to_persist, llm_adapter=LettaLLMRequestAdapter(llm_client=self.llm_client, llm_config=self.agent_state.llm_config), dry_run=True, + enforce_run_id_set=False, ) async for chunk in response: request = chunk # First chunk contains request data @@ -339,13 +341,14 @@ class LettaAgentV2(BaseAgentV2): self, messages: list[Message], llm_adapter: LettaLLMAdapter, + run_id: Optional[str], input_messages_to_persist: list[Message] | None = None, - run_id: str | None = None, use_assistant_message: bool = True, include_return_message_types: list[MessageType] | None = None, request_start_timestamp_ns: int | None = None, remaining_turns: int = -1, dry_run: bool = False, + enforce_run_id_set: bool = True, ) -> AsyncGenerator[LettaMessage | dict, None]: """ Execute a single agent step (one LLM call and tool execution). 
@@ -368,6 +371,9 @@ class LettaAgentV2(BaseAgentV2): Yields: LettaMessage or dict: Chunks for streaming mode, or request data for dry_run """ + if enforce_run_id_set and run_id is None: + raise AssertionError("run_id is required when enforce_run_id_set is True") + step_progression = StepProgression.START # TODO(@caren): clean this up tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = ( @@ -465,6 +471,13 @@ class LettaAgentV2(BaseAgentV2): self.stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) raise ValueError("No tool calls found in response, model must make a tool call") + # TODO: how should be associate input messages with runs? + ## Set run_id on input messages before persisting + # if input_messages_to_persist and run_id: + # for message in input_messages_to_persist: + # if message.run_id is None: + # message.run_id = run_id + persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response( tool_call or llm_adapter.tool_call, [tool["name"] for tool in valid_tools], @@ -566,6 +579,7 @@ class LettaAgentV2(BaseAgentV2): for message in input_messages_to_persist: message.is_err = True message.step_id = step_id + message.run_id = run_id await self.message_manager.create_many_messages_async( input_messages_to_persist, actor=self.actor, @@ -609,8 +623,8 @@ class LettaAgentV2(BaseAgentV2): @trace_method async def _check_run_cancellation(self, run_id) -> bool: try: - job = await self.job_manager.get_job_by_id_async(job_id=run_id, actor=self.actor) - return job.status == JobStatus.cancelled + run = await self.run_manager.get_run_by_id(run_id=run_id, actor=self.actor) + return run.status == RunStatus.cancelled except Exception as e: # Log the error but don't fail the execution self.logger.warning(f"Failed to check job cancellation status for job {run_id}: {e}") @@ -783,7 +797,7 @@ class LettaAgentV2(BaseAgentV2): 
context_window_limit=self.agent_state.llm_config.context_window, usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), provider_id=None, - job_id=run_id, + run_id=run_id, step_id=step_id, project_id=self.agent_state.project_id, status=StepStatus.PENDING, @@ -887,11 +901,14 @@ class LettaAgentV2(BaseAgentV2): pre_computed_assistant_message_id=None, step_id=step_id, is_approval_response=True, + run_id=run_id, ) messages_to_persist = (initial_messages or []) + tool_call_messages + persisted_messages = await self.message_manager.create_many_messages_async( messages_to_persist, actor=self.actor, + run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id, ) @@ -925,6 +942,7 @@ class LettaAgentV2(BaseAgentV2): reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, + run_id=run_id, ) messages_to_persist = (initial_messages or []) + [approval_message] continue_stepping = False @@ -1000,21 +1018,15 @@ class LettaAgentV2(BaseAgentV2): reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, + run_id=run_id, is_approval_response=is_approval or is_denial, ) messages_to_persist = (initial_messages or []) + tool_call_messages persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id + messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id ) - if run_id: - await self.job_manager.add_messages_to_job_async( - job_id=run_id, - message_ids=[m.id for m in persisted_messages if m.role != "user"], - actor=self.actor, - ) - return persisted_messages, continue_stepping, stop_reason @trace_method @@ -1072,6 +1084,7 @@ class LettaAgentV2(BaseAgentV2): agent_state: AgentState, agent_step_span: Span | None = None, 
step_id: str | None = None, + run_id: str = None, ) -> "ToolExecutionResult": """ Executes a tool and returns the ToolExecutionResult. @@ -1097,9 +1110,9 @@ class LettaAgentV2(BaseAgentV2): tool_execution_manager = ToolExecutionManager( agent_state=agent_state, message_manager=self.message_manager, + run_manager=self.run_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, - job_manager=self.job_manager, passage_manager=self.passage_manager, sandbox_env_vars=sandbox_env_vars, actor=self.actor, @@ -1182,7 +1195,7 @@ class LettaAgentV2(BaseAgentV2): tool_execution_ns=step_metrics.tool_execution_ns, step_ns=step_metrics.step_ns, agent_id=self.agent_state.id, - job_id=run_id, + run_id=run_id, project_id=self.agent_state.project_id, template_id=self.agent_state.template_id, base_template_id=self.agent_state.base_template_id, @@ -1206,15 +1219,15 @@ class LettaAgentV2(BaseAgentV2): if request_span: request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(duration_ns)}) await self._update_agent_last_run_metrics(now, ns_to_ms(duration_ns)) - if settings.track_agent_run and run_id: - await self.job_manager.record_response_duration(run_id, duration_ns, self.actor) - await self.job_manager.safe_update_job_status_async( - job_id=run_id, - new_status=JobStatus.failed if is_error else JobStatus.completed, - actor=self.actor, - stop_reason=self.stop_reason.stop_reason if self.stop_reason else StopReasonType.error, - metadata=job_update_metadata, - ) + # if settings.track_agent_run and run_id: + # await self.job_manager.record_response_duration(run_id, duration_ns, self.actor) + # await self.job_manager.safe_update_job_status_async( + # job_id=run_id, + # new_status=JobStatus.failed if is_error else JobStatus.completed, + # actor=self.actor, + # stop_reason=self.stop_reason.stop_reason if self.stop_reason else StopReasonType.error, + # metadata=job_update_metadata, + # ) if request_span: request_span.end() diff --git 
a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 14dd768b..8aab1e98 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -472,9 +472,11 @@ class LettaAgentV3(LettaAgentV2): for message in input_messages_to_persist: message.is_err = True message.step_id = step_id + message.run_id = run_id await self.message_manager.create_many_messages_async( input_messages_to_persist, actor=self.actor, + run_id=run_id, project_id=self.agent_state.project_id, template_id=self.agent_state.template_id, ) @@ -555,14 +557,23 @@ class LettaAgentV3(LettaAgentV2): reasoning_content=None, pre_computed_assistant_message_id=None, step_id=step_id, + run_id=run_id, is_approval_response=True, force_set_request_heartbeat=False, add_heartbeat_on_continue=False, ) messages_to_persist = (initial_messages or []) + tool_call_messages + + # Set run_id on all messages before persisting + for message in messages_to_persist: + if message.run_id is None: + message.run_id = run_id + print("MESSSAGE RUN ID", message.run_id, run_id) + persisted_messages = await self.message_manager.create_many_messages_async( messages_to_persist, actor=self.actor, + run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id, ) @@ -603,6 +614,7 @@ class LettaAgentV3(LettaAgentV2): reasoning_content=content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, + run_id=run_id, is_approval_response=is_approval or is_denial, force_set_request_heartbeat=False, add_heartbeat_on_continue=False, @@ -642,6 +654,7 @@ class LettaAgentV3(LettaAgentV2): reasoning_content=content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, + run_id=run_id, ) messages_to_persist = (initial_messages or []) + [approval_message] continue_stepping = False @@ -719,22 +732,22 @@ class LettaAgentV3(LettaAgentV2): reasoning_content=content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, 
step_id=step_id, + run_id=run_id, is_approval_response=is_approval or is_denial, force_set_request_heartbeat=False, add_heartbeat_on_continue=False, ) messages_to_persist = (initial_messages or []) + tool_call_messages - persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id - ) + # Set run_id on all messages before persisting + for message in messages_to_persist: + if message.run_id is None: + message.run_id = run_id + print("MESSSAGE RUN ID", message.run_id, run_id) - if run_id: - await self.job_manager.add_messages_to_job_async( - job_id=run_id, - message_ids=[m.id for m in persisted_messages if m.role != "user"], - actor=self.actor, - ) + persisted_messages = await self.message_manager.create_many_messages_async( + messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id + ) return persisted_messages, continue_stepping, stop_reason diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index f320986e..54946138 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -504,7 +504,7 @@ class VoiceAgent(BaseAgent): keyword_results = {} if convo_keyword_queries: for keyword in convo_keyword_queries: - messages = await self.message_manager.list_messages_for_agent_async( + messages = await self.message_manager.list_messages( agent_id=self.agent_id, actor=self.actor, query_text=keyword, diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index 879241c2..563600b7 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -268,7 +268,7 @@ class SleeptimeMultiAgentV2(BaseAgent): prior_messages = [] if self.group.sleeptime_agent_frequency: try: - prior_messages = await self.message_manager.list_messages_for_agent_async( + prior_messages = await 
self.message_manager.list_messages( agent_id=foreground_agent_id, actor=self.actor, after=last_processed_message_id, diff --git a/letta/groups/sleeptime_multi_agent_v3.py b/letta/groups/sleeptime_multi_agent_v3.py index d8c49399..4bc0646a 100644 --- a/letta/groups/sleeptime_multi_agent_v3.py +++ b/letta/groups/sleeptime_multi_agent_v3.py @@ -7,14 +7,14 @@ from letta.constants import DEFAULT_MAX_STEPS from letta.groups.helpers import stringify_message from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState -from letta.schemas.enums import JobStatus +from letta.schemas.enums import JobStatus, RunStatus from letta.schemas.group import Group, ManagerType from letta.schemas.job import JobUpdate from letta.schemas.letta_message import MessageType from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageCreate -from letta.schemas.run import Run +from letta.schemas.run import Run, RunUpdate from letta.schemas.user import User from letta.services.group_manager import GroupManager from letta.utils import safe_create_task @@ -134,14 +134,14 @@ class SleeptimeMultiAgentV3(LettaAgentV2): use_assistant_message: bool = True, ) -> str: run = Run( - user_id=self.actor.id, - status=JobStatus.created, + agent_id=sleeptime_agent_id, + status=RunStatus.created, metadata={ - "job_type": "sleeptime_agent_send_message_async", # is this right? + "run_type": "sleeptime_agent_send_message_async", # is this right? 
"agent_id": sleeptime_agent_id, }, ) - run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor) + run = await self.run_manager.create_run(pydantic_run=run, actor=self.actor) safe_create_task( self._participant_agent_step( @@ -167,15 +167,15 @@ class SleeptimeMultiAgentV3(LettaAgentV2): use_assistant_message: bool = True, ) -> LettaResponse: try: - # Update job status - job_update = JobUpdate(status=JobStatus.running) - await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor) + # Update run status + run_update = RunUpdate(status=RunStatus.running) + await self.run_manager.update_run_by_id_async(run_id=run_id, update=run_update, actor=self.actor) # Create conversation transcript prior_messages = [] if self.group.sleeptime_agent_frequency: try: - prior_messages = await self.message_manager.list_messages_for_agent_async( + prior_messages = await self.message_manager.list_messages( agent_id=foreground_agent_id, actor=self.actor, after=last_processed_message_id, @@ -212,22 +212,22 @@ class SleeptimeMultiAgentV3(LettaAgentV2): use_assistant_message=use_assistant_message, ) - # Update job status - job_update = JobUpdate( - status=JobStatus.completed, + # Update run status + run_update = RunUpdate( + status=RunStatus.completed, completed_at=datetime.now(timezone.utc).replace(tzinfo=None), metadata={ "result": result.model_dump(mode="json"), "agent_id": sleeptime_agent_state.id, }, ) - await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor) + await self.run_manager.update_run_by_id_async(run_id=run_id, update=run_update, actor=self.actor) return result except Exception as e: - job_update = JobUpdate( - status=JobStatus.failed, + run_update = RunUpdate( + status=RunStatus.failed, completed_at=datetime.now(timezone.utc).replace(tzinfo=None), metadata={"error": str(e)}, ) - await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, 
actor=self.actor) + await self.run_manager.update_run_by_id_async(run_id=run_id, update=run_update, actor=self.actor) raise diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 1121197e..72d5f056 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -1,5 +1,4 @@ from letta.orm.agent import Agent -from letta.orm.agents_runs import AgentsRuns from letta.orm.agents_tags import AgentsTags from letta.orm.archive import Archive from letta.orm.archives_agents import ArchivesAgents @@ -16,7 +15,6 @@ from letta.orm.identities_agents import IdentitiesAgents from letta.orm.identities_blocks import IdentitiesBlocks from letta.orm.identity import Identity from letta.orm.job import Job -from letta.orm.job_messages import JobMessage from letta.orm.llm_batch_items import LLMBatchItem from letta.orm.llm_batch_job import LLMBatchJob from letta.orm.mcp_oauth import MCPOAuth @@ -28,6 +26,7 @@ from letta.orm.passage_tag import PassageTag from letta.orm.prompt import Prompt from letta.orm.provider import Provider from letta.orm.provider_trace import ProviderTrace +from letta.orm.run import Run from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.sources_agents import SourcesAgents diff --git a/letta/orm/agent.py b/letta/orm/agent.py index bb12d53a..fb6bf782 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from letta.orm.files_agents import FileAgent from letta.orm.identity import Identity from letta.orm.organization import Organization + from letta.orm.run import Run from letta.orm.source import Source from letta.orm.tool import Tool @@ -133,8 +134,8 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin lazy="selectin", doc="Tags associated with the agent.", ) - runs: Mapped[List["AgentsRuns"]] = relationship( - "AgentsRuns", + runs: Mapped[List["Run"]] = relationship( + "Run", 
back_populates="agent", cascade="all, delete-orphan", lazy="selectin", diff --git a/letta/orm/agents_runs.py b/letta/orm/agents_runs.py deleted file mode 100644 index b8cfe724..00000000 --- a/letta/orm/agents_runs.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import TYPE_CHECKING - -from sqlalchemy import ForeignKey, Index, String, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from letta.orm.base import Base - -if TYPE_CHECKING: - from letta.orm.agent import Agent - from letta.orm.job import Job - - -class AgentsRuns(Base): - __tablename__ = "agents_runs" - __table_args__ = ( - UniqueConstraint("agent_id", "run_id", name="unique_agent_run"), - Index("ix_agents_runs_agent_id_run_id", "agent_id", "run_id"), - Index("ix_agents_runs_run_id_agent_id", "run_id", "agent_id"), - ) - - agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) - run_id: Mapped[str] = mapped_column(String, ForeignKey("jobs.id"), primary_key=True) - - # relationships - agent: Mapped["Agent"] = relationship("Agent", back_populates="runs") - run: Mapped["Job"] = relationship("Job", back_populates="agent") diff --git a/letta/orm/job.py b/letta/orm/job.py index 048ea514..9d46a01a 100644 --- a/letta/orm/job.py +++ b/letta/orm/job.py @@ -11,11 +11,8 @@ from letta.schemas.job import Job as PydanticJob, LettaRequestConfig from letta.schemas.letta_stop_reason import StopReasonType if TYPE_CHECKING: - from letta.orm.agents_runs import AgentsRuns - from letta.orm.job_messages import JobMessage from letta.orm.message import Message from letta.orm.organization import Organization - from letta.orm.step import Step from letta.orm.user import User @@ -62,11 +59,8 @@ class Job(SqlalchemyBase, UserMixin): # relationships user: Mapped["User"] = relationship("User", back_populates="jobs") - job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan") - steps: Mapped[List["Step"]] = 
relationship("Step", back_populates="job", cascade="save-update") # organization relationship (nullable for backward compatibility) organization: Mapped[Optional["Organization"]] = relationship("Organization", back_populates="jobs") - agent: Mapped[List["AgentsRuns"]] = relationship("AgentsRuns", back_populates="run", cascade="all, delete-orphan") @property def messages(self) -> List["Message"]: diff --git a/letta/orm/job_messages.py b/letta/orm/job_messages.py deleted file mode 100644 index 063febfc..00000000 --- a/letta/orm/job_messages.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import TYPE_CHECKING - -from sqlalchemy import ForeignKey, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from letta.orm.sqlalchemy_base import SqlalchemyBase - -if TYPE_CHECKING: - from letta.orm.job import Job - from letta.orm.message import Message - - -class JobMessage(SqlalchemyBase): - """Tracks messages that were created during job execution.""" - - __tablename__ = "job_messages" - __table_args__ = (UniqueConstraint("job_id", "message_id", name="unique_job_message"),) - - id: Mapped[int] = mapped_column(primary_key=True, doc="Unique identifier for the job message") - job_id: Mapped[str] = mapped_column( - ForeignKey("jobs.id", ondelete="CASCADE"), - nullable=False, # A job message must belong to a job - doc="ID of the job that created the message", - ) - message_id: Mapped[str] = mapped_column( - ForeignKey("messages.id", ondelete="CASCADE"), - nullable=False, # A job message must have a message - doc="ID of the message created by the job", - ) - - # Relationships - job: Mapped["Job"] = relationship("Job", back_populates="job_messages") - message: Mapped["Message"] = relationship("Message", back_populates="job_message") diff --git a/letta/orm/message.py b/letta/orm/message.py index 76b9a8c0..71028cb4 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -35,6 +35,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): 
step_id: Mapped[Optional[str]] = mapped_column( ForeignKey("steps.id", ondelete="SET NULL"), nullable=True, doc="ID of the step that this message belongs to" ) + run_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("runs.id", ondelete="SET NULL"), nullable=True, doc="ID of the run that this message belongs to" + ) otid: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The offline threading ID associated with this message") tool_returns: Mapped[List[ToolReturn]] = mapped_column( ToolReturnColumn, nullable=True, doc="Tool execution return information for prior tool calls" @@ -68,11 +71,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): # Relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="raise") step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin") - - # Job relationship - job_message: Mapped[Optional["JobMessage"]] = relationship( - "JobMessage", back_populates="message", uselist=False, cascade="all, delete-orphan", single_parent=True - ) + run: Mapped["Run"] = relationship("Run", back_populates="messages", lazy="selectin") @property def job(self) -> Optional["Job"]: diff --git a/letta/orm/organization.py b/letta/orm/organization.py index bb32a9d6..8e1c0f6d 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from letta.orm.passage import ArchivalPassage, SourcePassage from letta.orm.passage_tag import PassageTag from letta.orm.provider import Provider + from letta.orm.run import Run from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.tool import Tool from letta.orm.user import User @@ -68,3 +69,4 @@ class Organization(SqlalchemyBase): "LLMBatchItem", back_populates="organization", cascade="all, delete-orphan" ) jobs: Mapped[List["Job"]] = relationship("Job", back_populates="organization", cascade="all, delete-orphan") 
+ runs: Mapped[List["Run"]] = relationship("Run", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/orm/run.py b/letta/orm/run.py new file mode 100644 index 00000000..45d3d956 --- /dev/null +++ b/letta/orm/run.py @@ -0,0 +1,71 @@ +import uuid +from datetime import datetime +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy import JSON, BigInteger, Boolean, DateTime, ForeignKey, Index, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.enums import RunStatus +from letta.schemas.job import LettaRequestConfig +from letta.schemas.letta_stop_reason import StopReasonType +from letta.schemas.run import Run as PydanticRun + +if TYPE_CHECKING: + from letta.orm.agent import Agent + from letta.orm.message import Message + from letta.orm.organization import Organization + from letta.orm.step import Step + + +class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin): + """Runs are created when agents process messages and represent a conversation or processing session. + Unlike Jobs, Runs are specifically tied to agent interactions and message processing. 
+ """ + + __tablename__ = "runs" + __pydantic_model__ = PydanticRun + __table_args__ = ( + Index("ix_runs_created_at", "created_at", "id"), + Index("ix_runs_agent_id", "agent_id"), + Index("ix_runs_organization_id", "organization_id"), + ) + + # Generate run ID with run- prefix + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"run-{uuid.uuid4()}") + + # Core run fields + status: Mapped[RunStatus] = mapped_column(String, default=RunStatus.created, doc="The current status of the run.") + completed_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="The unix timestamp of when the run was completed.") + stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(String, nullable=True, doc="The reason why the run was stopped.") + background: Mapped[Optional[bool]] = mapped_column( + Boolean, nullable=True, default=False, doc="Whether the run was created in background mode." + ) + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, doc="The metadata of the run.") + request_config: Mapped[Optional[LettaRequestConfig]] = mapped_column( + JSON, nullable=True, doc="The request configuration for the run, stored as JSON." + ) + + # Agent relationship - A run belongs to one agent + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), nullable=False, doc="The agent that owns this run.") + + # Callback related columns + callback_url: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="When set, POST to this URL after run completion.") + callback_sent_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="Timestamp when the callback was last attempted.") + callback_status_code: Mapped[Optional[int]] = mapped_column(nullable=True, doc="HTTP status code returned by the callback endpoint.") + callback_error: Mapped[Optional[str]] = mapped_column( + nullable=True, doc="Optional error message from attempting to POST the callback endpoint." 
+ ) + + # Timing metrics (in nanoseconds for precision) + ttft_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, doc="Time to first token in nanoseconds") + total_duration_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, doc="Total run duration in nanoseconds") + + # Relationships + agent: Mapped["Agent"] = relationship("Agent", back_populates="runs") + organization: Mapped[Optional["Organization"]] = relationship("Organization", back_populates="runs") + + # Steps that are part of this run + steps: Mapped[List["Step"]] = relationship("Step", back_populates="run", cascade="all, delete-orphan") + messages: Mapped[List["Message"]] = relationship("Message", back_populates="run", cascade="all, delete-orphan") diff --git a/letta/orm/step.py b/letta/orm/step.py index 444ab403..7a76649b 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -10,10 +10,10 @@ from letta.schemas.enums import StepStatus from letta.schemas.step import Step as PydanticStep if TYPE_CHECKING: - from letta.orm.job import Job from letta.orm.message import Message from letta.orm.organization import Organization from letta.orm.provider import Provider + from letta.orm.run import Run from letta.orm.step_metrics import StepMetrics @@ -22,7 +22,7 @@ class Step(SqlalchemyBase, ProjectMixin): __tablename__ = "steps" __pydantic_model__ = PydanticStep - __table_args__ = (Index("ix_steps_job_id", "job_id"),) + __table_args__ = (Index("ix_steps_run_id", "run_id"),) id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"step-{uuid.uuid4()}") origin: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The surface that this agent step was initiated from.") @@ -36,8 +36,8 @@ class Step(SqlalchemyBase, ProjectMixin): nullable=True, doc="The unique identifier of the provider that was configured for this step", ) - job_id: Mapped[Optional[str]] = mapped_column( - ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identified 
of the job run that triggered this step" + run_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("runs.id", ondelete="SET NULL"), nullable=True, doc="The unique identifier of the run that this step belongs to" ) agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.") provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.") @@ -69,7 +69,7 @@ class Step(SqlalchemyBase, ProjectMixin): # Relationships (foreign keys) organization: Mapped[Optional["Organization"]] = relationship("Organization") provider: Mapped[Optional["Provider"]] = relationship("Provider") - job: Mapped[Optional["Job"]] = relationship("Job", back_populates="steps") + run: Mapped[Optional["Run"]] = relationship("Run", back_populates="steps") # Relationships (backrefs) messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload") diff --git a/letta/orm/step_metrics.py b/letta/orm/step_metrics.py index 6f8f4114..941bef16 100644 --- a/letta/orm/step_metrics.py +++ b/letta/orm/step_metrics.py @@ -13,7 +13,7 @@ from letta.settings import DatabaseChoice, settings if TYPE_CHECKING: from letta.orm.agent import Agent - from letta.orm.job import Job + from letta.orm.run import Run from letta.orm.step import Step @@ -38,10 +38,10 @@ class StepMetrics(SqlalchemyBase, ProjectMixin, AgentMixin): nullable=True, doc="The unique identifier of the provider", ) - job_id: Mapped[Optional[str]] = mapped_column( - ForeignKey("jobs.id", ondelete="SET NULL"), + run_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("runs.id", ondelete="SET NULL"), nullable=True, - doc="The unique identifier of the job", + doc="The unique identifier of the run", ) step_start_ns: Mapped[Optional[int]] = mapped_column( BigInteger, @@ -81,7 +81,7 @@ class StepMetrics(SqlalchemyBase, ProjectMixin, AgentMixin): # Relationships (foreign keys) step: 
Mapped["Step"] = relationship("Step", back_populates="metrics", uselist=False) - job: Mapped[Optional["Job"]] = relationship("Job") + run: Mapped[Optional["Run"]] = relationship("Run") agent: Mapped[Optional["Agent"]] = relationship("Agent") def create( diff --git a/letta/schemas/agent_file.py b/letta/schemas/agent_file.py index 73477c2e..39ffa869 100644 --- a/letta/schemas/agent_file.py +++ b/letta/schemas/agent_file.py @@ -168,7 +168,7 @@ class AgentSchema(CreateAgent): per_file_view_window_char_limit=agent_state.per_file_view_window_char_limit, ) - messages = await message_manager.list_messages_for_agent_async( + messages = await message_manager.list_messages( agent_id=agent_state.id, actor=actor, limit=50 ) # TODO: Expand to get more messages diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 8e099918..f6bb041a 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -51,6 +51,11 @@ class MessageRole(str, Enum): approval = "approval" +class MessageSourceType(str, Enum): + input = "input" # external input + output = "output" # internal output + + class OptionState(str, Enum): """Useful for kwargs that are bool + default option""" @@ -78,6 +83,18 @@ class JobStatus(StrEnum): return self in (JobStatus.completed, JobStatus.failed, JobStatus.cancelled, JobStatus.expired) +class RunStatus(StrEnum): + """ + Status of the run. + """ + + created = "created" + running = "running" + completed = "completed" + failed = "failed" + cancelled = "cancelled" + + class AgentStepStatus(str, Enum): """ Status of agent step. 
diff --git a/letta/schemas/job.py b/letta/schemas/job.py index 73c6e049..007381fd 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -1,8 +1,11 @@ from datetime import datetime -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from pydantic import BaseModel, ConfigDict, Field +if TYPE_CHECKING: + from letta.schemas.letta_request import LettaRequest + from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.helpers.datetime_helpers import get_utc_time from letta.schemas.enums import JobStatus, JobType @@ -112,3 +115,13 @@ class LettaRequestConfig(BaseModel): include_return_message_types: Optional[List[MessageType]] = Field( default=None, description="Only return specified message types in the response. If `None` (default) returns all messages." ) + + @classmethod + def from_letta_request(cls, request: "LettaRequest") -> "LettaRequestConfig": + """Create a LettaRequestConfig from a LettaRequest.""" + return cls( + use_assistant_message=request.use_assistant_message, + assistant_message_tool_name=request.assistant_message_tool_name, + assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + include_return_message_types=request.include_return_message_types, + ) diff --git a/letta/schemas/letta_stop_reason.py b/letta/schemas/letta_stop_reason.py index 2e5f91a8..fe62d742 100644 --- a/letta/schemas/letta_stop_reason.py +++ b/letta/schemas/letta_stop_reason.py @@ -3,7 +3,7 @@ from typing import Literal from pydantic import BaseModel, Field -from letta.schemas.enums import JobStatus +from letta.schemas.enums import RunStatus class StopReasonType(str, Enum): @@ -19,14 +19,14 @@ class StopReasonType(str, Enum): requires_approval = "requires_approval" @property - def run_status(self) -> JobStatus: + def run_status(self) -> RunStatus: if self in ( StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule, StopReasonType.requires_approval, ): - return JobStatus.completed 
+ return RunStatus.completed elif self in ( StopReasonType.error, StopReasonType.invalid_tool_call, @@ -34,9 +34,9 @@ class StopReasonType(str, Enum): StopReasonType.invalid_llm_response, StopReasonType.llm_api_error, ): - return JobStatus.failed + return RunStatus.failed elif self == StopReasonType.cancelled: - return JobStatus.cancelled + return RunStatus.cancelled else: raise ValueError("Unknown StopReasonType") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index c7bf6780..5f6ec826 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -194,6 +194,7 @@ class Message(BaseMessage): tool_call_id: Optional[str] = Field(default=None, description="The ID of the tool call. Only applicable for role tool.") # Extras step_id: Optional[str] = Field(default=None, description="The id of the step that this message was created in.") + run_id: Optional[str] = Field(default=None, description="The id of the run that this message was created in.") otid: Optional[str] = Field(default=None, description="The offline threading id associated with this message") tool_returns: Optional[List[ToolReturn]] = Field(default=None, description="Tool execution return information for prior tool calls") group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in") @@ -210,6 +211,13 @@ class Message(BaseMessage): # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") + # validate that run_id is set + # @model_validator(mode="after") + # def validate_run_id(self): + # if self.run_id is None: + # raise ValueError("Run ID is required") + # return self + @field_validator("role") @classmethod def validate_role(cls, v: str) -> str: @@ -323,6 +331,7 @@ class Message(BaseMessage): approve=self.approve, approval_request_id=self.approval_request_id, 
reason=self.denial_reason, + run_id=self.run_id, ) messages.append(approval_response_message) else: @@ -353,6 +362,7 @@ class Message(BaseMessage): sender_id=self.sender_id, step_id=self.step_id, is_err=self.is_err, + run_id=self.run_id, ) ) else: @@ -367,6 +377,7 @@ class Message(BaseMessage): sender_id=self.sender_id, step_id=self.step_id, is_err=self.is_err, + run_id=self.run_id, ) ) @@ -401,6 +412,7 @@ class Message(BaseMessage): otid=otid, step_id=self.step_id, is_err=self.is_err, + run_id=self.run_id, ) ) @@ -432,6 +444,7 @@ class Message(BaseMessage): otid=otid, step_id=self.step_id, is_err=self.is_err, + run_id=self.run_id, ) ) @@ -489,6 +502,7 @@ class Message(BaseMessage): sender_id=self.sender_id, step_id=self.step_id, is_err=self.is_err, + run_id=self.run_id, ) ) else: @@ -506,6 +520,7 @@ class Message(BaseMessage): sender_id=self.sender_id, step_id=self.step_id, is_err=self.is_err, + run_id=self.run_id, ) ) return messages @@ -549,6 +564,7 @@ class Message(BaseMessage): sender_id=self.sender_id, step_id=self.step_id, is_err=self.is_err, + run_id=self.run_id, ) @staticmethod @@ -582,6 +598,7 @@ class Message(BaseMessage): sender_id=self.sender_id, step_id=self.step_id, is_err=self.is_err, + run_id=self.run_id, ) def _convert_system_message(self) -> SystemMessage: @@ -599,6 +616,7 @@ class Message(BaseMessage): otid=self.otid, sender_id=self.sender_id, step_id=self.step_id, + run_id=self.run_id, ) @staticmethod @@ -612,6 +630,7 @@ class Message(BaseMessage): name: Optional[str] = None, group_id: Optional[str] = None, tool_returns: Optional[List[ToolReturn]] = None, + run_id: Optional[str] = None, ) -> Message: """Convert a ChatCompletion message object into a Message object (synced to DB)""" if not created_at: @@ -673,6 +692,7 @@ class Message(BaseMessage): id=str(id), tool_returns=tool_returns, group_id=group_id, + run_id=run_id, ) else: return Message( @@ -687,6 +707,7 @@ class Message(BaseMessage): created_at=created_at, tool_returns=tool_returns, 
group_id=group_id, + run_id=run_id, ) elif "function_call" in openai_message_dict and openai_message_dict["function_call"] is not None: @@ -722,6 +743,7 @@ class Message(BaseMessage): id=str(id), tool_returns=tool_returns, group_id=group_id, + run_id=run_id, ) else: return Message( @@ -736,6 +758,7 @@ class Message(BaseMessage): created_at=created_at, tool_returns=tool_returns, group_id=group_id, + run_id=run_id, ) else: @@ -771,6 +794,7 @@ class Message(BaseMessage): id=str(id), tool_returns=tool_returns, group_id=group_id, + run_id=run_id, ) else: return Message( @@ -785,6 +809,7 @@ class Message(BaseMessage): created_at=created_at, tool_returns=tool_returns, group_id=group_id, + run_id=run_id, ) def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict: diff --git a/letta/schemas/run.py b/letta/schemas/run.py index 433552aa..a03cd49d 100644 --- a/letta/schemas/run.py +++ b/letta/schemas/run.py @@ -1,62 +1,68 @@ +from datetime import datetime from typing import Optional -from pydantic import Field +from pydantic import ConfigDict, Field -from letta.schemas.enums import JobType -from letta.schemas.job import Job, JobBase, LettaRequestConfig +from letta.helpers.datetime_helpers import get_utc_time +from letta.schemas.enums import RunStatus +from letta.schemas.job import LettaRequestConfig +from letta.schemas.letta_base import LettaBase from letta.schemas.letta_stop_reason import StopReasonType -class RunBase(JobBase): - """Base class for Run schemas that inherits from JobBase but uses 'run' prefix for IDs""" - +class RunBase(LettaBase): __id_prefix__ = "run" - job_type: JobType = JobType.RUN class Run(RunBase): """ - Representation of a run, which is a job with a 'run' prefix in its ID. - Inherits all fields and behavior from Job except for the ID prefix. + Representation of a run - a conversation or processing session for an agent. 
+ Runs track when agents process messages and maintain the relationship between agents, steps, and messages. Parameters: id (str): The unique identifier of the run (prefixed with 'run-'). - status (JobStatus): The status of the run. - created_at (datetime): The unix timestamp of when the run was created. - completed_at (datetime): The unix timestamp of when the run was completed. - user_id (str): The unique identifier of the user associated with the run. + status (RunStatus): The current status of the run. + created_at (datetime): The timestamp when the run was created. + completed_at (datetime): The timestamp when the run was completed. + agent_id (str): The unique identifier of the agent associated with the run. + stop_reason (StopReasonType): The reason why the run was stopped. + background (bool): Whether the run was created in background mode. + metadata (dict): Additional metadata for the run. + request_config (LettaRequestConfig): The request configuration for the run. """ id: str = RunBase.generate_id_field() - user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the run.") + + # Core run fields + status: RunStatus = Field(default=RunStatus.created, description="The current status of the run.") + created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the run was created.") + completed_at: Optional[datetime] = Field(None, description="The timestamp when the run was completed.") + + # Agent relationship + agent_id: str = Field(..., description="The unique identifier of the agent associated with the run.") + + # Run configuration + background: Optional[bool] = Field(None, description="Whether the run was created in background mode.") + metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="Additional metadata for the run.") + request_config: Optional[LettaRequestConfig] = Field(None, description="The request configuration for the run.") stop_reason: 
Optional[StopReasonType] = Field(None, description="The reason why the run was stopped.") - @classmethod - def from_job(cls, job: Job) -> "Run": - """ - Convert a Job instance to a Run instance by replacing the ID prefix. - All other fields are copied as-is. + # Callback configuration + callback_url: Optional[str] = Field(None, description="If set, POST to this URL when the run completes.") + callback_sent_at: Optional[datetime] = Field(None, description="Timestamp when the callback was last attempted.") + callback_status_code: Optional[int] = Field(None, description="HTTP status code returned by the callback endpoint.") + callback_error: Optional[str] = Field(None, description="Optional error message from attempting to POST the callback endpoint.") - Args: - job: The Job instance to convert + # Timing metrics (in nanoseconds for precision) + ttft_ns: Optional[int] = Field(None, description="Time to first token for a run in nanoseconds") + total_duration_ns: Optional[int] = Field(None, description="Total run duration in nanoseconds") - Returns: - A new Run instance with the same data but 'run-' prefix in ID - """ - # Convert job dict to exclude None values - job_data = job.model_dump(exclude_none=True) - # Create new Run instance with converted data - return cls(**job_data) +class RunUpdate(RunBase): + """Update model for Run.""" - def to_job(self) -> Job: - """ - Convert this Run instance to a Job instance by replacing the ID prefix. - All other fields are copied as-is. 
- - Returns: - A new Job instance with the same data but 'job-' prefix in ID - """ - run_data = self.model_dump(exclude_none=True) - return Job(**run_data) + status: Optional[RunStatus] = Field(None, description="The status of the run.") + completed_at: Optional[datetime] = Field(None, description="The timestamp when the run was completed.") + stop_reason: Optional[StopReasonType] = Field(None, description="The reason why the run was stopped.") + metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="Additional metadata for the run.") + model_config = ConfigDict(extra="ignore") # Ignores extra fields diff --git a/letta/schemas/step.py b/letta/schemas/step.py index dcd8b4b1..ec5e7bdc 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -18,8 +18,8 @@ class Step(StepBase): origin: Optional[str] = Field(None, description="The surface that this agent step was initiated from.") organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the step.") provider_id: Optional[str] = Field(None, description="The unique identifier of the provider that was configured for this step") - job_id: Optional[str] = Field( - None, description="The unique identifier of the job that this step belongs to. Only included for async calls." + run_id: Optional[str] = Field( + None, description="The unique identifier of the run that this step belongs to. Only included for async calls." 
) agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.") provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.") diff --git a/letta/schemas/step_metrics.py b/letta/schemas/step_metrics.py index 4069ad77..fb791fc0 100644 --- a/letta/schemas/step_metrics.py +++ b/letta/schemas/step_metrics.py @@ -13,7 +13,7 @@ class StepMetrics(StepMetricsBase): id: str = Field(..., description="The id of the step this metric belongs to (matches steps.id).") organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.") provider_id: Optional[str] = Field(None, description="The unique identifier of the provider.") - job_id: Optional[str] = Field(None, description="The unique identifier of the job.") + run_id: Optional[str] = Field(None, description="The unique identifier of the run.") agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") step_start_ns: Optional[int] = Field(None, description="The timestamp of the start of the step in nanoseconds.") llm_request_start_ns: Optional[int] = Field(None, description="The timestamp of the start of the llm request in nanoseconds.") diff --git a/letta/server/rest_api/redis_stream_manager.py b/letta/server/rest_api/redis_stream_manager.py index c7e184ef..6f97085c 100644 --- a/letta/server/rest_api/redis_stream_manager.py +++ b/letta/server/rest_api/redis_stream_manager.py @@ -8,9 +8,9 @@ from typing import AsyncIterator, Dict, List, Optional from letta.data_sources.redis_client import AsyncRedisClient from letta.log import get_logger -from letta.schemas.enums import JobStatus +from letta.schemas.enums import RunStatus from letta.schemas.user import User -from letta.services.job_manager import JobManager +from letta.services.run_manager import RunManager from letta.utils import safe_create_task logger = get_logger(__name__) @@ -194,7 +194,7 @@ async def 
create_background_stream_processor( redis_client: AsyncRedisClient, run_id: str, writer: Optional[RedisSSEStreamWriter] = None, - job_manager: Optional[JobManager] = None, + run_manager: Optional[RunManager] = None, actor: Optional[User] = None, ) -> None: """ @@ -208,8 +208,8 @@ async def create_background_stream_processor( redis_client: Redis client instance run_id: The run ID to store chunks under writer: Optional pre-configured writer (creates new if not provided) - job_manager: Optional job manager for updating job status - actor: Optional actor for job status updates + run_manager: Optional run manager for updating run status + actor: Optional actor for run status updates """ if writer is None: writer = RedisSSEStreamWriter(redis_client) @@ -235,9 +235,9 @@ async def create_background_stream_processor( # Write error chunk # error_chunk = {"error": {"message": str(e)}} # Mark run_id terminal state - if job_manager and actor: - await job_manager.safe_update_job_status_async( - job_id=run_id, new_status=JobStatus.failed, actor=actor, metadata={"error": str(e)} + if run_manager and actor: + await run_manager.safe_update_run_status_async( + run_id=run_id, new_status=RunStatus.failed, actor=actor, metadata={"error": str(e)} ) error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"} diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 2596fd4b..240b286b 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -30,13 +30,13 @@ from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.otel.context import get_ctx_attributes from letta.otel.metric_registry import MetricRegistry -from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent +from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.agent_file import AgentFileSchema from letta.schemas.block import Block, BlockUpdate 
-from letta.schemas.enums import JobType +from letta.schemas.enums import RunStatus from letta.schemas.file import AgentFileAttachment, PaginatedAgentFiles from letta.schemas.group import Group -from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig +from letta.schemas.job import LettaRequestConfig from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse @@ -50,7 +50,7 @@ from letta.schemas.memory import ( ) from letta.schemas.message import MessageCreate, MessageSearchRequest, MessageSearchResult from letta.schemas.passage import Passage -from letta.schemas.run import Run +from letta.schemas.run import Run as PydanticRun, RunUpdate from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.user import User @@ -58,6 +58,7 @@ from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator from letta.server.server import SyncServer +from letta.services.run_manager import RunManager from letta.settings import settings from letta.utils import safe_create_shielded_task, safe_create_task, truncate_file_visible_content @@ -195,7 +196,7 @@ async def export_agent( if use_legacy_format: # Use the legacy serialization method try: - agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps) + agent = await server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps) return agent.model_dump() except NoResultFound: raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") @@ -432,8 +433,6 @@ async def create_agent( """ Create an 
agent. """ - # TODO remove - # agent.agent_type = AgentType.letta_v1_agent try: actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) return await server.create_agent_async(agent, actor=actor) @@ -1117,8 +1116,6 @@ async def modify_message( """ # TODO: support modifying tool calls/returns actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - - # TODO: implement return await server.message_manager.update_message_by_letta_message_async( message_id=message_id, letta_message_update=request, actor=actor ) @@ -1168,33 +1165,26 @@ async def send_message( # Create a new run for execution tracking if settings.track_agent_run: - job_status = JobStatus.created - run = await server.job_manager.create_job_async( - pydantic_job=Run( - user_id=actor.id, - status=job_status, + runs_manager = RunManager() + run = await runs_manager.create_run( + pydantic_run=PydanticRun( agent_id=agent_id, background=False, metadata={ - "job_type": "send_message", + "run_type": "send_message", }, - request_config=LettaRequestConfig( - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, - include_return_message_types=request.include_return_message_types, - ), + request_config=LettaRequestConfig.from_letta_request(request), ), actor=actor, ) else: run = None - job_update_metadata = None # TODO (cliandy): clean this up redis_client = await get_redis_client() await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None) + run_update_metadata = None try: result = None if agent_eligible and model_compatible: @@ -1220,17 +1210,17 @@ async def send_message( assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, include_return_message_types=request.include_return_message_types, ) - job_status = result.stop_reason.stop_reason.run_status + run_status = 
result.stop_reason.stop_reason.run_status return result except PendingApprovalError as e: - job_update_metadata = {"error": str(e)} - job_status = JobStatus.failed + run_update_metadata = {"error": str(e)} + run_status = RunStatus.failed raise HTTPException( status_code=409, detail={"code": "PENDING_APPROVAL", "message": str(e), "pending_request_id": e.pending_request_id} ) except Exception as e: - job_update_metadata = {"error": str(e)} - job_status = JobStatus.failed + run_update_metadata = {"error": str(e)} + run_status = RunStatus.failed raise finally: if settings.track_agent_run: @@ -1239,12 +1229,14 @@ async def send_message( else: # NOTE: we could also consider this an error? stop_reason = None - await server.job_manager.safe_update_job_status_async( - job_id=run.id, - new_status=job_status, + await server.run_manager.update_run_by_id_async( + run_id=run.id, + update=RunUpdate( + status=run_status, + metadata=run_update_metadata, + stop_reason=stop_reason, + ), actor=actor, - metadata=job_update_metadata, - stop_reason=stop_reason, ) @@ -1301,28 +1293,21 @@ async def send_message_streaming( ] model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock", "deepseek"] - # Create a new job for execution tracking + # Create a new run for execution tracking if settings.track_agent_run: - job_status = JobStatus.created - run = await server.job_manager.create_job_async( - pydantic_job=Run( - user_id=actor.id, - status=job_status, + runs_manager = RunManager() + run = await runs_manager.create_run( + pydantic_run=PydanticRun( agent_id=agent_id, background=request.background or False, metadata={ - "job_type": "send_message_streaming", + "run_type": "send_message_streaming", }, - request_config=LettaRequestConfig( - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, - 
include_return_message_types=request.include_return_message_types, - ), + request_config=LettaRequestConfig.from_letta_request(request), ), actor=actor, ) - job_update_metadata = None + run_update_metadata = None await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None) else: run = None @@ -1398,7 +1383,7 @@ async def send_message_streaming( stream_generator=raw_stream, redis_client=redis_client, run_id=run.id, - job_manager=server.job_manager, + run_manager=server.run_manager, actor=actor, ), label=f"background_stream_processor_{run.id}", @@ -1434,24 +1419,24 @@ async def send_message_streaming( include_return_message_types=request.include_return_message_types, ) if settings.track_agent_run: - job_status = JobStatus.running + run_status = RunStatus.running return result except PendingApprovalError as e: if settings.track_agent_run: - job_update_metadata = {"error": str(e)} - job_status = JobStatus.failed + run_update_metadata = {"error": str(e)} + run_status = RunStatus.failed raise HTTPException( status_code=409, detail={"code": "PENDING_APPROVAL", "message": str(e), "pending_request_id": e.pending_request_id} ) except Exception as e: if settings.track_agent_run: - job_update_metadata = {"error": str(e)} - job_status = JobStatus.failed + run_update_metadata = {"error": str(e)} + run_status = RunStatus.failed raise finally: if settings.track_agent_run: - await server.job_manager.safe_update_job_status_async( - job_id=run.id, new_status=job_status, actor=actor, metadata=job_update_metadata + await server.run_manager.update_run_by_id_async( + run_id=run.id, update=RunUpdate(status=run_status, metadata=run_update_metadata), actor=actor ) @@ -1480,31 +1465,30 @@ async def cancel_agent_run( run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}") if run_id is None: logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.") - job_ids = await server.job_manager.list_jobs_async( + run_ids = await 
server.run_manager.list_runs( actor=actor, - statuses=[JobStatus.created, JobStatus.running], - job_type=JobType.RUN, + statuses=[RunStatus.created, RunStatus.running], ascending=False, - agent_ids=[agent_id], + agent_id=agent_id, # NOTE: this will override agent_ids if provided ) - run_ids = [Run.from_job(job).id for job in job_ids] + run_ids = [run.id for run in run_ids] else: run_ids = [run_id] results = {} for run_id in run_ids: - run = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor) + run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor) if run.metadata.get("lettuce") and settings.temporal_endpoint: client = await Client.connect( settings.temporal_endpoint, namespace=settings.temporal_namespace, api_key=settings.temporal_api_key, - tls=settings.temporal_tls, # This should be false for local runs + tls=True, # This should be false for local runs ) await client.cancel_workflow(run_id) - success = await server.job_manager.safe_update_job_status_async( - job_id=run_id, - new_status=JobStatus.cancelled, + success = await server.run_manager.update_run_by_id_async( + run_id=run_id, + update=RunUpdate(status=RunStatus.cancelled), actor=actor, ) results[run_id] = "cancelled" if success else "failed" @@ -1559,7 +1543,7 @@ async def _process_message_background( max_steps: int = DEFAULT_MAX_STEPS, include_return_message_types: list[MessageType] | None = None, ) -> None: - """Background task to process the message and update job status.""" + """Background task to process the message and update run status.""" request_start_timestamp_ns = get_utc_timestamp_ns() try: agent = await server.agent_manager.get_agent_by_id_async( @@ -1596,7 +1580,7 @@ async def _process_message_background( input_messages=messages, stream_steps=False, stream_tokens=False, - metadata={"job_id": run_id}, + metadata={"run_id": run_id}, # Support for AssistantMessage use_assistant_message=use_assistant_message, 
assistant_message_tool_name=assistant_message_tool_name, @@ -1604,34 +1588,40 @@ async def _process_message_background( include_return_message_types=include_return_message_types, ) - job_update = JobUpdate( - status=JobStatus.completed, - completed_at=datetime.now(timezone.utc), - metadata={"result": result.model_dump(mode="json")}, + runs_manager = RunManager() + from letta.schemas.enums import RunStatus + + await runs_manager.update_run_by_id_async( + run_id=run_id, + update=RunUpdate(status=RunStatus.completed, stop_reason=result.stop_reason.stop_reason), + actor=actor, ) - await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor) except PendingApprovalError as e: - # Update job status to failed with specific error info - job_update = JobUpdate( - status=JobStatus.failed, - completed_at=datetime.now(timezone.utc), - metadata={"error": str(e), "error_code": "PENDING_APPROVAL", "pending_request_id": e.pending_request_id}, + # Update run status to failed with specific error info + runs_manager = RunManager() + from letta.schemas.enums import RunStatus + + await runs_manager.update_run_by_id_async( + run_id=run_id, + update=RunUpdate(status=RunStatus.failed), + actor=actor, ) - await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor) except Exception as e: - # Update job status to failed - job_update = JobUpdate( - status=JobStatus.failed, - completed_at=datetime.now(timezone.utc), - metadata={"error": str(e)}, + # Update run status to failed + runs_manager = RunManager() + from letta.schemas.enums import RunStatus + + await runs_manager.update_run_by_id_async( + run_id=run_id, + update=RunUpdate(status=RunStatus.failed), + actor=actor, ) - await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor) @router.post( "/{agent_id}/messages/async", - response_model=Run, + response_model=PydanticRun, operation_id="create_agent_message_async", ) async 
def send_message_async( @@ -1644,33 +1634,26 @@ async def send_message_async( Asynchronously process a user message and return a run object. The actual processing happens in the background, and the status can be checked using the run ID. - This is "asynchronous" in the sense that it's a background job and explicitly must be fetched by the run ID. - This is more like `send_message_job` + This is "asynchronous" in the sense that it's a background run and explicitly must be fetched by the run ID. """ MetricRegistry().user_message_counter.add(1, get_ctx_attributes()) actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - - # Create a new job + # Create a new run use_lettuce = headers.experimental_params.message_async and settings.temporal_endpoint is not None - run = Run( - user_id=actor.id, - status=JobStatus.created, + run = PydanticRun( callback_url=request.callback_url, agent_id=agent_id, background=True, # Async endpoints are always background metadata={ - "job_type": "send_message_async", - "agent_id": agent_id, + "run_type": "send_message_async", "lettuce": use_lettuce, }, - request_config=LettaRequestConfig( - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, - include_return_message_types=request.include_return_message_types, - ), + request_config=LettaRequestConfig.from_letta_request(request), + ) + run = await server.run_manager.create_run( + pydantic_run=run, + actor=actor, ) - run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor) if use_lettuce: agent_state = await server.agent_manager.get_agent_by_id_async( @@ -1713,17 +1696,21 @@ async def send_message_async( # Don't mark as failed since the shielded task is still running except Exception as e: logger.error(f"Unhandled exception in background task for run {run.id}: {e}") - safe_create_task( - 
server.job_manager.update_job_by_id_async( - job_id=run.id, - job_update=JobUpdate( - status=JobStatus.failed, - completed_at=datetime.now(timezone.utc), - metadata={"error": str(e)}, - ), + from letta.services.run_manager import RunManager + + async def update_failed_run(): + runs_manager = RunManager() + from letta.schemas.enums import RunStatus + + await runs_manager.update_run_by_id_async( + run_id=run.id, + update=RunUpdate(status=RunStatus.failed), actor=actor, - ), - label=f"update_failed_job_{run.id}", + ) + + safe_create_task( + update_failed_run(), + label=f"update_failed_run_{run.id}", ) task.add_done_callback(handle_task_completion) diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 8e7f9d43..d5a6efac 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -8,7 +8,7 @@ from temporalio.client import Client from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client from letta.helpers.datetime_helpers import get_utc_time from letta.orm.errors import NoResultFound -from letta.schemas.enums import JobStatus, JobType +from letta.schemas.enums import RunStatus from letta.schemas.letta_message import LettaMessageUnion from letta.schemas.letta_request import RetrieveStreamRequest from letta.schemas.letta_stop_reason import StopReasonType @@ -23,6 +23,7 @@ from letta.server.rest_api.streaming_response import ( cancellation_aware_stream_wrapper, ) from letta.server.server import SyncServer +from letta.services.run_manager import RunManager from letta.settings import settings router = APIRouter(prefix="/runs", tags=["runs"]) @@ -49,27 +50,26 @@ async def list_runs( List all runs. 
""" actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() + statuses = None if active: - statuses = [JobStatus.created, JobStatus.running] + statuses = [RunStatus.created, RunStatus.running] if agent_id: # NOTE: we are deprecating agent_ids so this will the primary path soon agent_ids = [agent_id] - jobs = await server.job_manager.list_jobs_async( + runs = await runs_manager.list_runs( actor=actor, + agent_ids=agent_ids, statuses=statuses, - job_type=JobType.RUN, limit=limit, before=before, after=after, - ascending=False, + ascending=ascending, stop_reason=stop_reason, - # agent_id=agent_id, - agent_ids=agent_ids, background=background, ) - runs = [Run.from_job(job) for job in jobs] return runs @@ -84,16 +84,16 @@ async def list_active_runs( List all active runs. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() if agent_id: agent_ids = [agent_id] else: agent_ids = None - active_runs = await server.job_manager.list_jobs_async( - actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN, agent_ids=agent_ids, background=background + active_runs = await runs_manager.list_runs( + actor=actor, statuses=[RunStatus.created, RunStatus.running], agent_ids=agent_ids, background=background ) - active_runs = [Run.from_job(job) for job in active_runs] return active_runs @@ -108,12 +108,13 @@ async def retrieve_run( Get the status of a run. 
""" actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() try: - job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor) + run = await runs_manager.get_run_by_id(run_id=run_id, actor=actor) - use_lettuce = job.metadata.get("lettuce") and settings.temporal_endpoint - if use_lettuce and job.status not in [JobStatus.completed, JobStatus.failed, JobStatus.cancelled]: + use_lettuce = run.metadata and run.metadata.get("lettuce") and settings.temporal_endpoint + if use_lettuce and run.status not in [RunStatus.completed, RunStatus.failed, RunStatus.cancelled]: client = await Client.connect( settings.temporal_endpoint, namespace=settings.temporal_namespace, @@ -126,25 +127,17 @@ async def retrieve_run( desc = await handle.describe() # Map the status to our enum - job_status = JobStatus.created + run_status = RunStatus.created if desc.status.name == "RUNNING": - job_status = JobStatus.running + run_status = RunStatus.running elif desc.status.name == "COMPLETED": - job_status = JobStatus.completed + run_status = RunStatus.completed elif desc.status.name == "FAILED": - job_status = JobStatus.failed + run_status = RunStatus.failed elif desc.status.name == "CANCELED": - job_status = JobStatus.canceled - # elif desc.status.name == "TERMINATED": - # job_status = JobStatus.terminated - # elif desc.status.name == "TIMED_OUT": - # job_status = JobStatus.timed_out - # elif desc.status.name == "CONTINUED_AS_NEW": - # return WorkflowStatus.CONTINUED_AS_NEW - # else: - # return WorkflowStatus.UNKNOWN - job.status = job_status - return Run.from_job(job) + run_status = RunStatus.cancelled + run.status = run_status + return run except NoResultFound: raise HTTPException(status_code=404, detail="Run not found") @@ -176,23 +169,11 @@ async def list_run_messages( ): """Get response messages associated with a run.""" actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - - 
try: - messages = await server.job_manager.get_run_messages( - run_id=run_id, - actor=actor, - limit=limit, - before=before, - after=after, - ascending=(order == "asc"), - ) - return messages - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) + return await server.run_manager.get_run_messages(run_id=run_id, actor=actor, before=before, after=after, limit=limit, order=order) @router.get("/{run_id}/usage", response_model=UsageStatistics, operation_id="retrieve_run_usage") -def retrieve_run_usage( +async def retrieve_run_usage( run_id: str, headers: HeaderParams = Depends(get_headers), server: "SyncServer" = Depends(get_letta_server), @@ -200,8 +181,14 @@ def retrieve_run_usage( """ Get usage statistics for a run. """ - actor = server.user_manager.get_user_or_default(user_id=headers.actor_id) - raise Exception("Not implemented") + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() + + try: + usage = await runs_manager.get_run_usage(run_id=run_id, actor=actor) + return usage + except NoResultFound: + raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found") @router.get( @@ -237,10 +224,11 @@ async def list_run_steps( raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'") actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() try: - steps = await server.job_manager.get_job_steps( - job_id=run_id, + steps = await runs_manager.get_run_steps( + run_id=run_id, actor=actor, limit=limit, before=before, @@ -262,10 +250,11 @@ async def delete_run( Delete a run by its run_id. 
""" actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() try: - job = await server.job_manager.delete_job_by_id_async(job_id=run_id, actor=actor) - return Run.from_job(job) + run = await runs_manager.delete_run_by_id(run_id=run_id, actor=actor) + return run except NoResultFound: raise HTTPException(status_code=404, detail="Run not found") @@ -309,13 +298,13 @@ async def retrieve_stream( server: "SyncServer" = Depends(get_letta_server), ): actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() + try: - job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor) + run = await runs_manager.get_run_by_id(run_id=run_id, actor=actor) except NoResultFound: raise HTTPException(status_code=404, detail="Run not found") - run = Run.from_job(job) - if not run.background: raise HTTPException(status_code=400, detail="Run was not created in background mode, so it cannot be retrieved.") @@ -345,8 +334,8 @@ async def retrieve_stream( if settings.enable_cancellation_aware_streaming: stream = cancellation_aware_stream_wrapper( stream_generator=stream, - job_manager=server.job_manager, - job_id=run_id, + run_manager=server.run_manager, + run_id=run_id, actor=actor, ) diff --git a/letta/server/rest_api/routers/v1/voice.py b/letta/server/rest_api/routers/v1/voice.py index 4356e230..af31948f 100644 --- a/letta/server/rest_api/routers/v1/voice.py +++ b/letta/server/rest_api/routers/v1/voice.py @@ -52,7 +52,7 @@ async def create_voice_chat_completions( message_manager=server.message_manager, agent_manager=server.agent_manager, block_manager=server.block_manager, - job_manager=server.job_manager, + run_manager=server.run_manager, passage_manager=server.passage_manager, actor=actor, ) diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py index 295e3f1f..ee332af7 100644 --- 
a/letta/server/rest_api/streaming_response.py +++ b/letta/server/rest_api/streaming_response.py @@ -13,23 +13,23 @@ from starlette.types import Send from letta.errors import LettaUnexpectedStreamCancellationError, PendingApprovalError from letta.log import get_logger -from letta.schemas.enums import JobStatus +from letta.schemas.enums import RunStatus from letta.schemas.letta_ping import LettaPing from letta.schemas.user import User from letta.server.rest_api.utils import capture_sentry_exception -from letta.services.job_manager import JobManager +from letta.services.run_manager import RunManager from letta.settings import settings from letta.utils import safe_create_task logger = get_logger(__name__) -class JobCancelledException(Exception): - """Exception raised when a job is explicitly cancelled (not due to client timeout)""" +class RunCancelledException(Exception): + """Exception raised when a run is explicitly cancelled (not due to client timeout)""" - def __init__(self, job_id: str, message: str = None): - self.job_id = job_id - super().__init__(message or f"Job {job_id} was explicitly cancelled") + def __init__(self, run_id: str, message: str = None): + self.run_id = run_id + super().__init__(message or f"Run {run_id} was explicitly cancelled") async def add_keepalive_to_stream( @@ -109,21 +109,21 @@ async def add_keepalive_to_stream( # TODO (cliandy) wrap this and handle types async def cancellation_aware_stream_wrapper( stream_generator: AsyncIterator[str | bytes], - job_manager: JobManager, - job_id: str, + run_manager: RunManager, + run_id: str, actor: User, cancellation_check_interval: float = 0.5, ) -> AsyncIterator[str | bytes]: """ - Wraps a stream generator to provide real-time job cancellation checking. + Wraps a stream generator to provide real-time run cancellation checking. 
- This wrapper periodically checks for job cancellation while streaming and + This wrapper periodically checks for run cancellation while streaming and can interrupt the stream at any point, not just at step boundaries. Args: stream_generator: The original stream generator to wrap - job_manager: Job manager instance for checking job status - job_id: ID of the job to monitor for cancellation + run_manager: Run manager instance for checking run status + run_id: ID of the run to monitor for cancellation actor: User/actor making the request cancellation_check_interval: How often to check for cancellation (seconds) @@ -131,7 +131,7 @@ async def cancellation_aware_stream_wrapper( Stream chunks from the original generator until cancelled Raises: - asyncio.CancelledError: If the job is cancelled during streaming + asyncio.CancelledError: If the run is cancelled during streaming """ last_cancellation_check = asyncio.get_event_loop().time() @@ -141,32 +141,32 @@ async def cancellation_aware_stream_wrapper( current_time = asyncio.get_event_loop().time() if current_time - last_cancellation_check >= cancellation_check_interval: try: - job = await job_manager.get_job_by_id_async(job_id=job_id, actor=actor) - if job.status == JobStatus.cancelled: - logger.info(f"Stream cancelled for job {job_id}, interrupting stream") + run = await run_manager.get_run_by_id_async(run_id=run_id, actor=actor) + if run.status == RunStatus.cancelled: + logger.info(f"Stream cancelled for run {run_id}, interrupting stream") # Send cancellation event to client cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"} yield f"data: {json.dumps(cancellation_event)}\n\n" - # Raise custom exception for explicit job cancellation - raise JobCancelledException(job_id, f"Job {job_id} was cancelled") + # Raise custom exception for explicit run cancellation + raise RunCancelledException(run_id, f"Run {run_id} was cancelled") except Exception as e: # Log warning but don't fail the stream if 
cancellation check fails - logger.warning(f"Failed to check job cancellation for job {job_id}: {e}") + logger.warning(f"Failed to check run cancellation for run {run_id}: {e}") last_cancellation_check = current_time yield chunk - except JobCancelledException: - # Re-raise JobCancelledException to distinguish from client timeout - logger.info(f"Stream for job {job_id} was explicitly cancelled and cleaned up") + except RunCancelledException: + # Re-raise RunCancelledException to distinguish from client timeout + logger.info(f"Stream for run {run_id} was explicitly cancelled and cleaned up") raise except asyncio.CancelledError: # Re-raise CancelledError (likely client timeout) to ensure proper cleanup - logger.info(f"Stream for job {job_id} was cancelled (likely client timeout) and cleaned up") + logger.info(f"Stream for run {run_id} was cancelled (likely client timeout) and cleaned up") raise except Exception as e: - logger.error(f"Error in cancellation-aware stream wrapper for job {job_id}: {e}") + logger.error(f"Error in cancellation-aware stream wrapper for run {run_id}: {e}") raise @@ -267,12 +267,12 @@ class StreamingResponseWithStatusCode(StreamingResponse): self._client_connected = False # Continue processing but don't try to send more data - # Handle explicit job cancellations (should not throw error) - except JobCancelledException as exc: - logger.info(f"Stream was explicitly cancelled for job {exc.job_id}") + # Handle explicit run cancellations (should not throw error) + except RunCancelledException as exc: + logger.info(f"Stream was explicitly cancelled for run {exc.run_id}") # Handle explicit cancellation gracefully without error more_body = False - cancellation_resp = {"message": "Job was cancelled"} + cancellation_resp = {"message": "Run was cancelled"} cancellation_event = f"event: cancelled\ndata: {json.dumps(cancellation_resp)}\n\n".encode(self.charset) if not self.response_started: await send( diff --git a/letta/server/rest_api/utils.py 
b/letta/server/rest_api/utils.py index 3e9be813..45c634bb 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -190,6 +190,7 @@ def create_approval_request_message_from_llm_response( reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, step_id: str | None = None, + run_id: str = None, ) -> Message: # Construct the tool call with the assistant's message # Force set request_heartbeat in tool_args to calculated continue_stepping @@ -213,6 +214,7 @@ def create_approval_request_message_from_llm_response( tool_call_id=tool_call_id, created_at=get_utc_time(), step_id=step_id, + run_id=run_id, ) if pre_computed_assistant_message_id: approval_message.id = pre_computed_assistant_message_id @@ -230,6 +232,8 @@ def create_letta_messages_from_llm_response( function_response: Optional[str], timezone: str, actor: User, + run_id: str | None = None, + step_id: str = None, continue_stepping: bool = False, heartbeat_reason: Optional[str] = None, reasoning_content: Optional[ @@ -237,7 +241,6 @@ def create_letta_messages_from_llm_response( ] = None, pre_computed_assistant_message_id: Optional[str] = None, llm_batch_item_id: Optional[str] = None, - step_id: str | None = None, is_approval_response: bool | None = None, # force set request_heartbeat, useful for v2 loop to ensure matching tool rules force_set_request_heartbeat: bool = True, @@ -278,6 +281,7 @@ def create_letta_messages_from_llm_response( tool_call_id=tool_call_id, created_at=get_utc_time(), batch_item_id=llm_batch_item_id, + run_id=run_id, ) else: # Safeguard against empty text messages @@ -300,6 +304,7 @@ def create_letta_messages_from_llm_response( tool_call_id=None, created_at=get_utc_time(), batch_item_id=llm_batch_item_id, + run_id=run_id, ) else: assistant_message = None @@ -330,6 +335,7 @@ def create_letta_messages_from_llm_response( # 
func_return=tool_execution_result.func_return, ) ], + run_id=run_id, ) messages.append(tool_message) diff --git a/letta/server/server.py b/letta/server/server.py index 5ff6eb59..8b49c3a5 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -100,6 +100,7 @@ from letta.services.message_manager import MessageManager from letta.services.organization_manager import OrganizationManager from letta.services.passage_manager import PassageManager from letta.services.provider_manager import ProviderManager +from letta.services.run_manager import RunManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager from letta.services.step_manager import StepManager @@ -160,6 +161,7 @@ class SyncServer(object): self.sandbox_config_manager = SandboxConfigManager() self.message_manager = MessageManager() self.job_manager = JobManager() + self.run_manager = RunManager() self.agent_manager = AgentManager() self.archive_manager = ArchiveManager() self.provider_manager = ProviderManager() @@ -644,7 +646,7 @@ class SyncServer(object): actor = await self.user_manager.get_actor_or_default_async(actor_id=user_id) - records = await self.message_manager.list_messages_for_agent_async( + records = await self.message_manager.list_messages( agent_id=agent_id, actor=actor, after=after, @@ -683,7 +685,7 @@ class SyncServer(object): assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, include_err: Optional[bool] = None, ) -> Union[List[Message], List[LettaMessage]]: - records = await self.message_manager.list_messages_for_agent_async( + records = await self.message_manager.list_messages( agent_id=agent_id, actor=actor, after=after, @@ -1222,7 +1224,7 @@ class SyncServer(object): message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, - job_manager=self.job_manager, + run_manager=self.run_manager, passage_manager=self.passage_manager, 
actor=actor, sandbox_env_vars=tool_env_vars, diff --git a/letta/services/helpers/run_manager_helper.py b/letta/services/helpers/run_manager_helper.py new file mode 100644 index 00000000..b8d1fd29 --- /dev/null +++ b/letta/services/helpers/run_manager_helper.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import Optional + +from sqlalchemy import asc, desc, nulls_last, select +from letta.settings import DatabaseChoice, settings + +from letta.orm.run import Run as RunModel +from letta.settings import DatabaseChoice, settings +from sqlalchemy import asc, desc +from typing import Optional + +from letta.services.helpers.agent_manager_helper import _cursor_filter + + +async def _apply_pagination_async( + query, + before: Optional[str], + after: Optional[str], + session, + ascending: bool = True, + sort_by: str = "created_at", +) -> any: + # Determine the sort column + if sort_by == "last_run_completion": + sort_column = RunModel.last_run_completion + sort_nulls_last = True # TODO: handle this as a query param eventually + else: + sort_column = RunModel.created_at + sort_nulls_last = False + + if after: + result = ( + await session.execute( + select(sort_column, RunModel.id).where(RunModel.id == after) + ) + ).first() + if result: + after_sort_value, after_id = result + # SQLite does not support as granular timestamping, so we need to round the timestamp + if settings.database_engine is DatabaseChoice.SQLITE and isinstance( + after_sort_value, datetime + ): + after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S") + query = query.where( + _cursor_filter( + sort_column, + RunModel.id, + after_sort_value, + after_id, + forward=ascending, + nulls_last=sort_nulls_last, + ) + ) + + if before: + result = ( + await session.execute( + select(sort_column, RunModel.id).where(RunModel.id == before) + ) + ).first() + if result: + before_sort_value, before_id = result + # SQLite does not support as granular timestamping, so we need to round the timestamp + if 
settings.database_engine is DatabaseChoice.SQLITE and isinstance( + before_sort_value, datetime + ): + before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S") + query = query.where( + _cursor_filter( + sort_column, + RunModel.id, + before_sort_value, + before_id, + forward=not ascending, + nulls_last=sort_nulls_last, + ) + ) + + # Apply ordering + order_fn = asc if ascending else desc + query = query.order_by( + nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), + order_fn(RunModel.id), + ) + return query diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 1d35de9f..615d0866 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -10,7 +10,6 @@ from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.job import Job as JobModel -from letta.orm.job_messages import JobMessage from letta.orm.message import Message as MessageModel from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step, Step as StepModel @@ -40,8 +39,6 @@ class JobManager: self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser ) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]: """Create a new job based on the JobCreate schema.""" - from letta.orm.agents_runs import AgentsRuns - async with db_registry.async_session() as session: # Associate the job with the user pydantic_job.user_id = actor.id @@ -55,17 +52,11 @@ class JobManager: job_data = pydantic_job.model_dump(to_orm=True) # Remove agent_id from job_data as it's not a field in the Job ORM model - # The relationship is handled through the AgentsRuns association table job_data.pop("agent_id", None) job = JobModel(**job_data) job.organization_id = actor.organization_id job = await job.create_async(session, actor=actor, no_commit=True, no_refresh=True) # Save job in the database - # If this is a Run with 
an agent_id, create the agents_runs association - if agent_id and isinstance(pydantic_job, PydanticRun): - agents_run = AgentsRuns(agent_id=agent_id, run_id=job.id) - session.add(agents_run) - await session.commit() # Convert to pydantic first, then add agent_id if needed @@ -223,8 +214,6 @@ class JobManager: """List all jobs with optional pagination and status filter.""" from sqlalchemy import and_, or_, select - from letta.orm.agents_runs import AgentsRuns - async with db_registry.async_session() as session: # build base query query = select(JobModel).where(JobModel.user_id == actor.id).where(JobModel.job_type == job_type) @@ -247,11 +236,6 @@ class JobManager: column = column.op("->>")("source_id") query = query.where(column == source_id) - # If agent_id filter is provided, join with agents_runs table - if agent_ids: - query = query.join(AgentsRuns, JobModel.id == AgentsRuns.run_id) - query = query.where(AgentsRuns.agent_id.in_(agent_ids)) - # handle cursor-based pagination if before or after: # get cursor objects @@ -324,33 +308,6 @@ class JobManager: await job.hard_delete_async(db_session=session, actor=actor) return job.to_pydantic() - @enforce_types - @trace_method - async def add_messages_to_job_async(self, job_id: str, message_ids: List[str], actor: PydanticUser) -> None: - """ - Associate a message with a job by creating a JobMessage record. - Each message can only be associated with one job. 
- - Args: - job_id: The ID of the job - message_id: The ID of the message to associate - actor: The user making the request - - Raises: - NoResultFound: If the job does not exist or user does not have access - """ - if not message_ids: - return - - async with db_registry.async_session() as session: - # First verify job exists and user has access - await self._verify_job_access_async(session, job_id, actor, access=["write"]) - - # Create new JobMessage associations - job_messages = [JobMessage(job_id=job_id, message_id=message_id) for message_id in message_ids] - session.add_all(job_messages) - await session.commit() - @enforce_types @trace_method async def get_run_messages( @@ -570,57 +527,6 @@ class JobManager: finally: return result - @enforce_types - @trace_method - async def get_job_messages( - self, - job_id: str, - actor: PydanticUser, - before: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 100, - role: Optional[MessageRole] = None, - ascending: bool = True, - ) -> List[PydanticMessage]: - """ - Get all messages associated with a job. 
- - Args: - job_id: The ID of the job to get messages for - actor: The user making the request - before: Cursor for pagination - after: Cursor for pagination - limit: Maximum number of messages to return - role: Optional filter for message role - ascending: Optional flag to sort in ascending order - - Returns: - List of messages associated with the job - - Raises: - NoResultFound: If the job does not exist or user does not have access - """ - async with db_registry.async_session() as session: - # Build filters - filters = {} - if role is not None: - filters["role"] = role - - # Get messages - messages = await MessageModel.list_async( - db_session=session, - before=before, - after=after, - ascending=ascending, - limit=limit, - actor=actor, - join_model=JobMessage, - join_conditions=[MessageModel.id == JobMessage.message_id, JobMessage.job_id == job_id], - **filters, - ) - - return [message.to_pydantic() for message in messages] - @enforce_types @trace_method async def get_job_steps( diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 89b74c02..956aeccd 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -307,6 +307,7 @@ class MessageManager: self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser, + run_id: Optional[str] = None, strict_mode: bool = False, project_id: Optional[str] = None, template_id: Optional[str] = None, @@ -669,8 +670,9 @@ class MessageManager: query_text: Optional[str] = None, limit: Optional[int] = 50, ascending: bool = True, + run_id: Optional[str] = None, ) -> List[PydanticMessage]: - return await self.list_messages_for_agent_async( + return await self.list_messages( agent_id=agent_id, actor=actor, after=after, @@ -679,14 +681,15 @@ class MessageManager: roles=[MessageRole.user], limit=limit, ascending=ascending, + run_id=run_id, ) @enforce_types @trace_method - async def list_messages_for_agent_async( + async def list_messages( self, - agent_id: str, actor: 
PydanticUser, + agent_id: Optional[str] = None, after: Optional[str] = None, before: Optional[str] = None, query_text: Optional[str] = None, @@ -695,9 +698,10 @@ class MessageManager: ascending: bool = True, group_id: Optional[str] = None, include_err: Optional[bool] = None, + run_id: Optional[str] = None, ) -> List[PydanticMessage]: """ - Most performant query to list messages for an agent by directly querying the Message table. + Most performant query to list messages by directly querying the Message table. This function filters by the agent_id (leveraging the index on messages.agent_id) and applies pagination using sequence_id as the cursor. @@ -715,6 +719,7 @@ class MessageManager: ascending: If True, sort by sequence_id ascending; if False, sort descending. group_id: Optional group ID to filter messages by group_id. include_err: Optional boolean to include errors and error statuses. Used for debugging only. + run_id: Optional run ID to filter messages by run_id. Returns: List[PydanticMessage]: A list of messages (converted via .to_pydantic()). @@ -725,15 +730,21 @@ class MessageManager: async with db_registry.async_session() as session: # Permission check: raise if the agent doesn't exist or actor is not allowed. - await validate_agent_exists_async(session, agent_id, actor) # Build a query that directly filters the Message table by agent_id. - query = select(MessageModel).where(MessageModel.agent_id == agent_id) + query = select(MessageModel) + + if agent_id: + await validate_agent_exists_async(session, agent_id, actor) + query = query.where(MessageModel.agent_id == agent_id) # If group_id is provided, filter messages by group_id. 
if group_id: query = query.where(MessageModel.group_id == group_id) + if run_id: + query = query.where(MessageModel.run_id == run_id) + if not include_err: query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None))) @@ -972,7 +983,7 @@ class MessageManager: except Exception as e: logger.error(f"Failed to search messages with Turbopuffer, falling back to SQL: {e}") # fall back to SQL search - messages = await self.list_messages_for_agent_async( + messages = await self.list_messages( agent_id=agent_id, actor=actor, query_text=query_text, @@ -992,7 +1003,7 @@ class MessageManager: return message_tuples else: # use sql-based search - messages = await self.list_messages_for_agent_async( + messages = await self.list_messages( agent_id=agent_id, actor=actor, query_text=query_text, diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py new file mode 100644 index 00000000..1272551d --- /dev/null +++ b/letta/services/run_manager.py @@ -0,0 +1,296 @@ +from datetime import datetime +from pickletools import pyunicode +from typing import List, Literal, Optional + +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from letta.helpers.datetime_helpers import get_utc_time +from letta.log import get_logger +from letta.orm.errors import NoResultFound +from letta.orm.message import Message as MessageModel +from letta.orm.run import Run as RunModel +from letta.orm.sqlalchemy_base import AccessType +from letta.orm.step import Step as StepModel +from letta.otel.tracing import log_event, trace_method +from letta.schemas.enums import MessageRole, RunStatus +from letta.schemas.job import LettaRequestConfig +from letta.schemas.letta_message import LettaMessage, LettaMessageUnion +from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType +from letta.schemas.message import Message as PydanticMessage +from letta.schemas.run 
import Run as PydanticRun, RunUpdate +from letta.schemas.step import Step as PydanticStep +from letta.schemas.usage import LettaUsageStatistics +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.services.helpers.agent_manager_helper import validate_agent_exists_async +from letta.services.message_manager import MessageManager +from letta.services.step_manager import StepManager +from letta.utils import enforce_types + +logger = get_logger(__name__) + + +class RunManager: + """Manager class to handle business logic related to Runs.""" + + def __init__(self): + """Initialize the RunManager.""" + self.step_manager = StepManager() + self.message_manager = MessageManager() + + @enforce_types + async def create_run(self, pydantic_run: PydanticRun, actor: PydanticUser) -> PydanticRun: + """Create a new run.""" + async with db_registry.async_session() as session: + # Get agent_id from the pydantic object + agent_id = pydantic_run.agent_id + + # Verify agent exists before creating the run + await validate_agent_exists_async(session, agent_id, actor) + organization_id = actor.organization_id + + run_data = pydantic_run.model_dump(exclude_none=True) + # Handle metadata field mapping (Pydantic uses 'metadata', ORM uses 'metadata_') + if "metadata" in run_data: + run_data["metadata_"] = run_data.pop("metadata") + + run = RunModel(**run_data) + run.organization_id = organization_id + run = await run.create_async(session, actor=actor, no_commit=True, no_refresh=True) + await session.commit() + + return run.to_pydantic() + + @enforce_types + async def get_run_by_id(self, run_id: str, actor: PydanticUser) -> PydanticRun: + """Get a run by its ID.""" + async with db_registry.async_session() as session: + run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION) + if not run: + raise NoResultFound(f"Run with id {run_id} not found") + return run.to_pydantic() + + 
@enforce_types + async def list_runs( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + agent_ids: Optional[List[str]] = None, + statuses: Optional[List[RunStatus]] = None, + limit: Optional[int] = 50, + before: Optional[str] = None, + after: Optional[str] = None, + ascending: bool = False, + stop_reason: Optional[str] = None, + background: Optional[bool] = None, + ) -> List[PydanticRun]: + """List runs with filtering options.""" + async with db_registry.async_session() as session: + from sqlalchemy import select + + query = select(RunModel).filter(RunModel.organization_id == actor.organization_id) + + # Handle agent filtering + if agent_id: + agent_ids = [agent_id] + if agent_ids: + query = query.filter(RunModel.agent_id.in_(agent_ids)) + + # Filter by status + if statuses: + query = query.filter(RunModel.status.in_(statuses)) + + # Filter by stop reason + if stop_reason: + query = query.filter(RunModel.stop_reason == stop_reason) + + # Filter by background + if background is not None: + query = query.filter(RunModel.background == background) + + # Apply pagination + from letta.services.helpers.run_manager_helper import _apply_pagination_async + + query = await _apply_pagination_async(query, before, after, session, ascending=ascending) + + # Apply limit + if limit: + query = query.limit(limit) + + result = await session.execute(query) + runs = result.scalars().all() + return [run.to_pydantic() for run in runs] + + @enforce_types + async def delete_run(self, run_id: str, actor: PydanticUser) -> PydanticRun: + """Delete a run by its ID.""" + async with db_registry.async_session() as session: + run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION) + if not run: + raise NoResultFound(f"Run with id {run_id} not found") + + pydantic_run = run.to_pydantic() + await run.hard_delete_async(db_session=session, actor=actor) + + return pydantic_run + + @enforce_types + async def 
update_run_by_id_async( + self, + run_id: str, + update: RunUpdate, + actor: PydanticUser, + ) -> PydanticRun: + """Update a run using a RunUpdate object.""" + + async with db_registry.async_session() as session: + run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor) + + # Check if this is a terminal update and whether we should dispatch a callback + needs_callback = False + callback_url = None + not_completed_before = not bool(run.completed_at) + is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed} + if is_terminal_update and not_completed_before and run.callback_url: + needs_callback = True + callback_url = run.callback_url + + # Housekeeping only when the run is actually completing + if not_completed_before and is_terminal_update: + if not update.stop_reason: + logger.warning(f"Run {run_id} completed without a stop reason") + if not update.completed_at: + logger.warning(f"Run {run_id} completed without a completed_at timestamp") + update.completed_at = get_utc_time().replace(tzinfo=None) + + # Update job attributes with only the fields that were explicitly set + update_data = update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + + # Automatically update the completion timestamp if status is set to 'completed' + for key, value in update_data.items(): + # Ensure completed_at is timezone-naive for database compatibility + if key == "completed_at" and value is not None and hasattr(value, "replace"): + value = value.replace(tzinfo=None) + setattr(run, key, value) + + await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) + final_metadata = run.metadata_ + pydantic_run = run.to_pydantic() + await session.commit() + + # Dispatch callback outside of database session if needed + if needs_callback: + result = LettaResponse( + messages=await self.get_run_messages(run_id=run_id, actor=actor), + stop_reason=LettaStopReason(stop_reason=pydantic_run.stop_reason), + 
usage=await self.get_run_usage(run_id=run_id, actor=actor), + ) + final_metadata["result"] = result.model_dump() + callback_info = { + "run_id": run_id, + "callback_url": callback_url, + "status": update.status, + "completed_at": get_utc_time().replace(tzinfo=None), + "metadata": final_metadata, + } + callback_result = await self._dispatch_callback_async(callback_info) + + # Update callback status in a separate transaction + async with db_registry.async_session() as session: + run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor) + run.callback_sent_at = callback_result["callback_sent_at"] + run.callback_status_code = callback_result.get("callback_status_code") + run.callback_error = callback_result.get("callback_error") + pydantic_run = run.to_pydantic() + await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) + await session.commit() + + return pydantic_run + + @trace_method + async def _dispatch_callback_async(self, callback_info: dict) -> dict: + """ + POST a standard JSON payload to callback_url and return callback status asynchronously. 
+ """ + payload = { + "run_id": callback_info["run_id"], + "status": callback_info["status"], + "completed_at": callback_info["completed_at"].isoformat() if callback_info["completed_at"] else None, + "metadata": callback_info["metadata"], + } + + callback_sent_at = get_utc_time().replace(tzinfo=None) + result = {"callback_sent_at": callback_sent_at} + + try: + async with AsyncClient() as client: + log_event("POST callback dispatched", payload) + resp = await client.post(callback_info["callback_url"], json=payload, timeout=5.0) + log_event("POST callback finished") + result["callback_status_code"] = resp.status_code + except Exception as e: + error_message = f"Failed to dispatch callback for run {callback_info['run_id']} to {callback_info['callback_url']}: {e!s}" + logger.error(error_message) + result["callback_error"] = error_message + # Continue silently - callback failures should not affect run completion + finally: + return result + + @enforce_types + async def get_run_usage(self, run_id: str, actor: PydanticUser) -> LettaUsageStatistics: + """Get usage statistics for a run.""" + async with db_registry.async_session() as session: + run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION) + if not run: + raise NoResultFound(f"Run with id {run_id} not found") + + steps = await self.step_manager.list_steps_async(run_id=run_id, actor=actor) + total_usage = LettaUsageStatistics() + for step in steps: + total_usage.prompt_tokens += step.prompt_tokens + total_usage.completion_tokens += step.completion_tokens + total_usage.total_tokens += step.total_tokens + total_usage.step_count += 1 + return total_usage + + @enforce_types + async def get_run_messages( + self, + run_id: str, + actor: PydanticUser, + limit: Optional[int] = 100, + before: Optional[str] = None, + after: Optional[str] = None, + order: Literal["asc", "desc"] = "asc", + ) -> List[LettaMessage]: + """Get the result of a run.""" + request_config 
= await self.get_run_request_config(run_id=run_id, actor=actor) + + messages = await self.message_manager.list_messages( + actor=actor, + run_id=run_id, + limit=limit, + before=before, + after=after, + ascending=(order == "asc"), + ) + letta_messages = PydanticMessage.to_letta_messages_from_list(messages, reverse=(order != "asc")) + + if request_config and request_config.include_return_message_types: + include_return_message_types_set = set(request_config.include_return_message_types) + letta_messages = [msg for msg in letta_messages if msg.message_type in include_return_message_types_set] + + return letta_messages + + @enforce_types + async def get_run_request_config(self, run_id: str, actor: PydanticUser) -> Optional[LettaRequestConfig]: + """Get the letta request config from a run.""" + async with db_registry.async_session() as session: + run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION) + if not run: + raise NoResultFound(f"Run with id {run_id} not found") + pydantic_run = run.to_pydantic() + return pydantic_run.request_config diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 0c50285c..3609e11f 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -8,7 +8,6 @@ from sqlalchemy.orm import Session from letta.helpers.singleton import singleton from letta.orm.errors import NoResultFound -from letta.orm.job import Job as JobModel from letta.orm.message import Message as MessageModel from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step as StepModel @@ -48,6 +47,7 @@ class StepManager: feedback: Optional[Literal["positive", "negative"]] = None, has_feedback: Optional[bool] = None, project_id: Optional[str] = None, + run_id: Optional[str] = None, ) -> List[PydanticStep]: """List all jobs with optional pagination and status filter.""" async with db_registry.async_session() as session: @@ -62,6 +62,8 @@ class 
StepManager: filter_kwargs["feedback"] = feedback if project_id: filter_kwargs["project_id"] = project_id + if run_id: + filter_kwargs["run_id"] = run_id steps = await StepModel.list_async( db_session=session, before=before, @@ -75,6 +77,60 @@ class StepManager: ) return [step.to_pydantic() for step in steps] + @enforce_types + @trace_method + def log_step( + self, + actor: PydanticUser, + agent_id: str, + provider_name: str, + provider_category: str, + model: str, + model_endpoint: Optional[str], + context_window_limit: int, + usage: UsageStatistics, + provider_id: Optional[str] = None, + run_id: Optional[str] = None, + step_id: Optional[str] = None, + project_id: Optional[str] = None, + stop_reason: Optional[LettaStopReason] = None, + status: Optional[StepStatus] = None, + error_type: Optional[str] = None, + error_data: Optional[Dict] = None, + ) -> PydanticStep: + step_data = { + "origin": None, + "organization_id": actor.organization_id, + "agent_id": agent_id, + "provider_id": provider_id, + "provider_name": provider_name, + "provider_category": provider_category, + "model": model, + "model_endpoint": model_endpoint, + "context_window_limit": context_window_limit, + "completion_tokens": usage.completion_tokens, + "prompt_tokens": usage.prompt_tokens, + "total_tokens": usage.total_tokens, + "run_id": run_id, + "tags": [], + "tid": None, + "trace_id": get_trace_id(), # Get the current trace ID + "project_id": project_id, + "status": status if status else StepStatus.PENDING, + "error_type": error_type, + "error_data": error_data, + } + if step_id: + step_data["id"] = step_id + if stop_reason: + step_data["stop_reason"] = stop_reason.stop_reason + with db_registry.session() as session: + if run_id: + self._verify_run_access(session, run_id, actor, access=["write"]) + new_step = StepModel(**step_data) + new_step.create(session) + return new_step.to_pydantic() + @enforce_types @trace_method async def log_step_async( @@ -88,7 +144,7 @@ class StepManager: 
context_window_limit: int, usage: UsageStatistics, provider_id: Optional[str] = None, - job_id: Optional[str] = None, + run_id: Optional[str] = None, step_id: Optional[str] = None, project_id: Optional[str] = None, stop_reason: Optional[LettaStopReason] = None, @@ -110,7 +166,7 @@ class StepManager: "completion_tokens": usage.completion_tokens, "prompt_tokens": usage.prompt_tokens, "total_tokens": usage.total_tokens, - "job_id": job_id, + "run_id": run_id, "tags": [], "tid": None, "trace_id": get_trace_id(), # Get the current trace ID @@ -375,7 +431,7 @@ class StepManager: tool_execution_ns: Optional[int] = None, step_ns: Optional[int] = None, agent_id: Optional[str] = None, - job_id: Optional[str] = None, + run_id: Optional[str] = None, project_id: Optional[str] = None, template_id: Optional[str] = None, base_template_id: Optional[str] = None, @@ -390,7 +446,7 @@ class StepManager: tool_execution_ns: Time spent on tool execution in nanoseconds step_ns: Total time for the step in nanoseconds agent_id: The ID of the agent - job_id: The ID of the job + run_id: The ID of the run project_id: The ID of the project template_id: The ID of the template base_template_id: The ID of the base template @@ -419,7 +475,7 @@ class StepManager: "id": step_id, "organization_id": actor.organization_id, "agent_id": agent_id or step.agent_id, - "job_id": job_id or step.job_id, + "run_id": run_id, "project_id": project_id or step.project_id, "llm_request_ns": llm_request_ns, "tool_execution_ns": tool_execution_ns, @@ -432,62 +488,66 @@ class StepManager: await metrics.create_async(session) return metrics.to_pydantic() - def _verify_job_access( + def _verify_run_access( self, session: Session, - job_id: str, + run_id: str, actor: PydanticUser, access: List[Literal["read", "write", "delete"]] = ["read"], - ) -> JobModel: + ): """ - Verify that a job exists and the user has the required access. + Verify that a run exists and the user has the required access. 
Args: session: The database session - job_id: The ID of the job to verify + run_id: The ID of the run to verify actor: The user making the request Returns: - The job if it exists and the user has access + The run if it exists and the user has access Raises: - NoResultFound: If the job does not exist or user does not have access + NoResultFound: If the run does not exist or user does not have access """ - job_query = select(JobModel).where(JobModel.id == job_id) - job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER) - job = session.execute(job_query).scalar_one_or_none() - if not job: - raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") - return job + from letta.orm.run import Run as RunModel + + run_query = select(RunModel).where(RunModel.id == run_id) + run_query = RunModel.apply_access_predicate(run_query, actor, access, AccessType.USER) + run = session.execute(run_query).scalar_one_or_none() + if not run: + raise NoResultFound(f"Run with id {run_id} does not exist or user does not have access") + return run @staticmethod - async def _verify_job_access_async( + async def _verify_run_access_async( session: AsyncSession, - job_id: str, + run_id: str, actor: PydanticUser, access: List[Literal["read", "write", "delete"]] = ["read"], - ) -> JobModel: + ): """ - Verify that a job exists and the user has the required access asynchronously. + Verify that a run exists and the user has the required access asynchronously. 
Args: session: The async database session - job_id: The ID of the job to verify + run_id: The ID of the run to verify actor: The user making the request Returns: - The job if it exists and the user has access + The run if it exists and the user has access Raises: - NoResultFound: If the job does not exist or user does not have access + NoResultFound: If the run does not exist or user does not have access """ - job_query = select(JobModel).where(JobModel.id == job_id) - job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER) - result = await session.execute(job_query) - job = result.scalar_one_or_none() - if not job: - raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") - return job + from letta.orm.run import Run as RunModel + + run_query = select(RunModel).where(RunModel.id == run_id) + run_query = RunModel.apply_access_predicate(run_query, actor, access, AccessType.USER) + result = await session.execute(run_query) + run = result.scalar_one_or_none() + if not run: + raise NoResultFound(f"Run with id {run_id} does not exist or user does not have access") + return run # noinspection PyTypeChecker @@ -512,7 +572,7 @@ class NoopStepManager(StepManager): context_window_limit: int, usage: UsageStatistics, provider_id: Optional[str] = None, - job_id: Optional[str] = None, + run_id: Optional[str] = None, step_id: Optional[str] = None, project_id: Optional[str] = None, stop_reason: Optional[LettaStopReason] = None, @@ -535,7 +595,7 @@ class NoopStepManager(StepManager): context_window_limit: int, usage: UsageStatistics, provider_id: Optional[str] = None, - job_id: Optional[str] = None, + run_id: Optional[str] = None, step_id: Optional[str] = None, project_id: Optional[str] = None, stop_reason: Optional[LettaStopReason] = None, diff --git a/letta/services/tool_executor/files_tool_executor.py b/letta/services/tool_executor/files_tool_executor.py index 27a24afa..fb3c4718 100644 --- 
a/letta/services/tool_executor/files_tool_executor.py +++ b/letta/services/tool_executor/files_tool_executor.py @@ -20,9 +20,9 @@ from letta.services.block_manager import BlockManager from letta.services.file_manager import FileManager from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.files_agents_manager import FileAgentManager -from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.run_manager import RunManager from letta.services.source_manager import SourceManager from letta.services.tool_executor.tool_executor_base import ToolExecutor from letta.utils import get_friendly_error_msg @@ -47,7 +47,7 @@ class LettaFileToolExecutor(ToolExecutor): message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, - job_manager: JobManager, + run_manager: RunManager, passage_manager: PassageManager, actor: User, ): @@ -55,7 +55,7 @@ class LettaFileToolExecutor(ToolExecutor): message_manager=message_manager, agent_manager=agent_manager, block_manager=block_manager, - job_manager=job_manager, + run_manager=run_manager, passage_manager=passage_manager, actor=actor, ) diff --git a/letta/services/tool_executor/multi_agent_tool_executor.py b/letta/services/tool_executor/multi_agent_tool_executor.py index 76747c66..3d90a0e1 100644 --- a/letta/services/tool_executor/multi_agent_tool_executor.py +++ b/letta/services/tool_executor/multi_agent_tool_executor.py @@ -7,10 +7,12 @@ from letta.schemas.enums import MessageRole from letta.schemas.letta_message import AssistantMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.message import MessageCreate +from letta.schemas.run import Run from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from 
letta.schemas.user import User +from letta.services.run_manager import RunManager from letta.services.tool_executor.tool_executor_base import ToolExecutor from letta.settings import settings from letta.utils import safe_create_task @@ -43,13 +45,15 @@ class LettaMultiAgentToolExecutor(ToolExecutor): # Execute the appropriate function function_args_copy = function_args.copy() # Make a copy to avoid modifying the original - function_response = await function_map[function_name](agent_state, **function_args_copy) + function_response = await function_map[function_name](agent_state, actor, **function_args_copy) return ToolExecutionResult( status="success", func_return=function_response, ) - async def send_message_to_agent_and_wait_for_reply(self, agent_state: AgentState, message: str, other_agent_id: str) -> str: + async def send_message_to_agent_and_wait_for_reply( + self, agent_state: AgentState, actor: User, message: str, other_agent_id: str + ) -> str: augmented_message = ( f"[Incoming message from agent with ID '{agent_state.id}' - to reply to this message, " f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " @@ -57,10 +61,10 @@ class LettaMultiAgentToolExecutor(ToolExecutor): ) other_agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=other_agent_id, actor=self.actor) - return str(await self._process_agent(agent_state=other_agent_state, message=augmented_message)) + return str(await self._process_agent(agent_state=other_agent_state, message=augmented_message, actor=actor)) async def send_message_to_agents_matching_tags_async( - self, agent_state: AgentState, message: str, match_all: List[str], match_some: List[str] + self, agent_state: AgentState, actor: User, message: str, match_all: List[str], match_some: List[str] ) -> str: # Find matching agents matching_agents = await self.agent_manager.list_agents_matching_tags_async( @@ -78,20 +82,34 @@ class LettaMultiAgentToolExecutor(ToolExecutor): 
# Run concurrent requests and collect their return values. # Note: Do not wrap with safe_create_task here — it swallows return values (returns None). - coros = [self._process_agent(agent_state=a_state, message=augmented_message) for a_state in matching_agents] + coros = [self._process_agent(agent_state=a_state, message=augmented_message, actor=actor) for a_state in matching_agents] results = await asyncio.gather(*coros) return str(results) - async def _process_agent(self, agent_state: AgentState, message: str) -> Dict[str, Any]: + async def _process_agent(self, agent_state: AgentState, message: str, actor: User) -> Dict[str, Any]: from letta.agents.letta_agent_v2 import LettaAgentV2 try: + runs_manager = RunManager() + run = await runs_manager.create_run( + pydantic_run=Run( + agent_id=agent_state.id, + background=False, + metadata={ + "run_type": "agent_send_message_to_agent", # TODO: Make this a constant + }, + ), + actor=actor, + ) + letta_agent = LettaAgentV2( agent_state=agent_state, actor=self.actor, ) - letta_response = await letta_agent.step([MessageCreate(role=MessageRole.system, content=[TextContent(text=message)])]) + letta_response = await letta_agent.step( + [MessageCreate(role=MessageRole.system, content=[TextContent(text=message)])], run_id=run.id + ) messages = letta_response.messages send_message_content = [message.content for message in messages if isinstance(message, AssistantMessage)] @@ -108,7 +126,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor): "type": type(e).__name__, } - async def send_message_to_agent_async(self, agent_state: AgentState, message: str, other_agent_id: str) -> str: + async def send_message_to_agent_async(self, agent_state: AgentState, actor: User, message: str, other_agent_id: str) -> str: if settings.environment == "PRODUCTION": raise RuntimeError("This tool is not allowed to be run on Letta Cloud.") @@ -122,7 +140,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor): other_agent_state = await 
self.agent_manager.get_agent_by_id_async(agent_id=other_agent_id, actor=self.actor) task = safe_create_task( - self._process_agent(agent_state=other_agent_state, message=prefixed), label=f"send_message_to_{other_agent_id}" + self._process_agent(agent_state=other_agent_state, message=prefixed, actor=actor), label=f"send_message_to_{other_agent_id}" ) task.add_done_callback(lambda t: (logger.error(f"Async send_message task failed: {t.exception()}") if t.exception() else None)) diff --git a/letta/services/tool_executor/tool_execution_manager.py b/letta/services/tool_executor/tool_execution_manager.py index 00433c83..e149cd02 100644 --- a/letta/services/tool_executor/tool_execution_manager.py +++ b/letta/services/tool_executor/tool_execution_manager.py @@ -16,9 +16,9 @@ from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager -from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.run_manager import RunManager from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor from letta.services.tool_executor.composio_tool_executor import ExternalComposioToolExecutor from letta.services.tool_executor.core_tool_executor import LettaCoreToolExecutor @@ -51,7 +51,7 @@ class ToolExecutorFactory: message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, - job_manager: JobManager, + run_manager: RunManager, passage_manager: PassageManager, actor: User, ) -> ToolExecutor: @@ -61,7 +61,7 @@ class ToolExecutorFactory: message_manager=message_manager, agent_manager=agent_manager, block_manager=block_manager, - job_manager=job_manager, + run_manager=run_manager, passage_manager=passage_manager, actor=actor, ) @@ -75,7 +75,7 @@ class 
ToolExecutionManager: message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, - job_manager: JobManager, + run_manager: RunManager, passage_manager: PassageManager, actor: User, agent_state: Optional[AgentState] = None, @@ -85,7 +85,7 @@ class ToolExecutionManager: self.message_manager = message_manager self.agent_manager = agent_manager self.block_manager = block_manager - self.job_manager = job_manager + self.run_manager = run_manager self.passage_manager = passage_manager self.agent_state = agent_state self.logger = get_logger(__name__) @@ -107,7 +107,7 @@ class ToolExecutionManager: message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, - job_manager=self.job_manager, + run_manager=self.run_manager, passage_manager=self.passage_manager, actor=self.actor, ) diff --git a/letta/services/tool_executor/tool_executor_base.py b/letta/services/tool_executor/tool_executor_base.py index 452ce681..38bff595 100644 --- a/letta/services/tool_executor/tool_executor_base.py +++ b/letta/services/tool_executor/tool_executor_base.py @@ -8,9 +8,9 @@ from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager -from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.run_manager import RunManager class ToolExecutor(ABC): @@ -21,14 +21,14 @@ class ToolExecutor(ABC): message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, - job_manager: JobManager, + run_manager: RunManager, passage_manager: PassageManager, actor: User, ): self.message_manager = message_manager self.agent_manager = agent_manager self.block_manager = block_manager - self.job_manager = job_manager + self.run_manager = 
run_manager self.passage_manager = passage_manager self.actor = actor diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index b9adb176..e0e218ce 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -8,8 +8,10 @@ from letta.agents.letta_agent_v2 import LettaAgentV2 from letta.config import LettaConfig from letta.schemas.letta_message import ToolCallMessage from letta.schemas.message import MessageCreate +from letta.schemas.run import Run from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule from letta.server.server import SyncServer +from letta.services.run_manager import RunManager from letta.services.telemetry_manager import NoopTelemetryManager from tests.helpers.endpoints_helper import ( assert_invoked_function_call, @@ -244,8 +246,17 @@ async def run_agent_step(agent_state, input_messages, actor): actor=actor, ) + run = Run( + agent_id=agent_state.id, + ) + run = await RunManager().create_run( + pydantic_run=run, + actor=actor, + ) + return await agent_loop.step( input_messages, + run_id=run.id, max_steps=50, use_assistant_message=False, ) diff --git a/tests/integration_test_builtin_tools.py b/tests/integration_test_builtin_tools.py index 3885cda5..92efbf3a 100644 --- a/tests/integration_test_builtin_tools.py +++ b/tests/integration_test_builtin_tools.py @@ -303,7 +303,7 @@ async def test_web_search_uses_exa(): message_manager=MagicMock(), agent_manager=MagicMock(), block_manager=MagicMock(), - job_manager=MagicMock(), + run_manager=MagicMock(), passage_manager=MagicMock(), actor=MagicMock(), ) diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 5b545aff..209cd49b 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -1511,53 +1511,47 @@ def test_async_greeting_with_assistant_message( ) run = 
wait_for_run_completion(client, run.id) - result = run.metadata.get("result") - assert result is not None, "Run metadata missing 'result' key" - - messages = cast_message_dict_to_messages(result["messages"]) - assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) - messages = client.runs.messages.list(run_id=run.id) + usage = client.runs.usage.retrieve(run_id=run.id) + assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) -@pytest.mark.parametrize( - "llm_config", - TESTED_LLM_CONFIGS, - ids=[c.model for c in TESTED_LLM_CONFIGS], -) -def test_async_greeting_without_assistant_message( - disable_e2b_api_key: Any, - client: Letta, - agent_state: AgentState, - llm_config: LLMConfig, -) -> None: - """ - Tests sending a message as an asynchronous job using the synchronous client. - Waits for job completion and asserts that the result messages are as expected. 
- """ - last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) - client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) - - run = client.agents.messages.create_async( - agent_id=agent_state.id, - messages=USER_MESSAGE_FORCE_REPLY, - use_assistant_message=False, - ) - run = wait_for_run_completion(client, run.id) - - result = run.metadata.get("result") - assert result is not None, "Run metadata missing 'result' key" - - messages = cast_message_dict_to_messages(result["messages"]) - assert_greeting_without_assistant_message_response(messages, llm_config=llm_config) - - messages = client.runs.messages.list(run_id=run.id) - assert_greeting_without_assistant_message_response(messages, llm_config=llm_config) - messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False) - assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) +# NOTE: deprecated in preparation of letta_v1_agent +# @pytest.mark.parametrize( +# "llm_config", +# TESTED_LLM_CONFIGS, +# ids=[c.model for c in TESTED_LLM_CONFIGS], +# ) +# def test_async_greeting_without_assistant_message( +# disable_e2b_api_key: Any, +# client: Letta, +# agent_state: AgentState, +# llm_config: LLMConfig, +# ) -> None: +# """ +# Tests sending a message as an asynchronous job using the synchronous client. +# Waits for job completion and asserts that the result messages are as expected. 
+# """ +# last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) +# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) +# +# run = client.agents.messages.create_async( +# agent_id=agent_state.id, +# messages=USER_MESSAGE_FORCE_REPLY, +# use_assistant_message=False, +# ) +# run = wait_for_run_completion(client, run.id) +# +# messages = client.runs.messages.list(run_id=run.id) +# assert_greeting_without_assistant_message_response(messages, llm_config=llm_config) +# +# messages = client.runs.messages.list(run_id=run.id) +# assert_greeting_without_assistant_message_response(messages, llm_config=llm_config) +# messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False) +# assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) @pytest.mark.parametrize( @@ -1600,13 +1594,6 @@ def test_async_tool_call( request_options={"timeout_in_seconds": 300}, ) run = wait_for_run_completion(client, run.id) - - result = run.metadata.get("result") - assert result is not None, "Run metadata missing 'result' key" - - messages = cast_message_dict_to_messages(result["messages"]) - assert_tool_call_response(messages, llm_config=llm_config) - messages = client.runs.messages.list(run_id=run.id) assert_tool_call_response(messages, llm_config=llm_config) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) @@ -1737,10 +1724,7 @@ def test_async_greeting_with_callback_url( run = wait_for_run_completion(client, run.id) # Validate job completed successfully - result = run.metadata.get("result") - assert result is not None, "Run metadata missing 'result' key" - - messages = cast_message_dict_to_messages(result["messages"]) + messages = client.runs.messages.list(run_id=run.id) assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) # Validate callback was received @@ -1752,13 
+1736,13 @@ def test_async_greeting_with_callback_url( callback_data = callback["data"] # Check required fields - assert "job_id" in callback_data, "Callback missing 'job_id' field" + assert "run_id" in callback_data, "Callback missing 'run_id' field" assert "status" in callback_data, "Callback missing 'status' field" assert "completed_at" in callback_data, "Callback missing 'completed_at' field" assert "metadata" in callback_data, "Callback missing 'metadata' field" # Validate field values - assert callback_data["job_id"] == run.id, f"Job ID mismatch: {callback_data['job_id']} != {run.id}" + assert callback_data["run_id"] == run.id, f"Job ID mismatch: {callback_data['run_id']} != {run.id}" assert callback_data["status"] == "completed", f"Expected status 'completed', got {callback_data['status']}" assert callback_data["completed_at"] is not None, "completed_at should not be None" assert callback_data["metadata"] is not None, "metadata should not be None" @@ -1766,7 +1750,8 @@ def test_async_greeting_with_callback_url( # Validate that callback metadata contains the result assert "result" in callback_data["metadata"], "Callback metadata missing 'result' field" callback_result = callback_data["metadata"]["result"] - assert callback_result == result, "Callback result doesn't match job result" + callback_messages = cast_message_dict_to_messages(callback_result["messages"]) + assert callback_messages == messages, "Callback result doesn't match job result" # Validate HTTP headers headers = callback["headers"] diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 34e993c3..0f581f77 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -150,7 +150,7 @@ async def test_sleeptime_group_chat(client): run_ids.extend(response.usage.run_ids or []) runs = client.runs.list() - agent_runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] == 
sleeptime_agent_id] + agent_runs = [run for run in runs if run.agent_id == sleeptime_agent_id] assert len(agent_runs) == len(run_ids) # 6. Verify run status after sleep diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index 87105bbc..0ac684c1 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -900,7 +900,7 @@ async def test_message_embedding_without_config(server, default_user, sarah_agen assert all(msg.agent_id == sarah_agent.id for msg in created) # Messages should be in SQL - sql_messages = await server.message_manager.list_messages_for_agent_async( + sql_messages = await server.message_manager.list_messages( agent_id=sarah_agent.id, actor=default_user, limit=10, diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index 9e88f367..363c4a33 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -266,7 +266,7 @@ async def test_model_compatibility(model, message, server, server_url, actor, ro print(chunk.choices[0].delta.content) # Get the messages and assert based on the message type - messages = await server.message_manager.list_messages_for_agent_async(agent_id=main_agent.id, actor=actor) + messages = await server.message_manager.list_messages(agent_id=main_agent.id, actor=actor) # Find user message with our request user_messages = [msg for msg in messages if msg.role == MessageRole.user and message in str(msg.content)] assert len(user_messages) >= 1, f"Should find user message containing: {message}" diff --git a/tests/managers/conftest.py b/tests/managers/conftest.py index 36821ee1..6cc79435 100644 --- a/tests/managers/conftest.py +++ b/tests/managers/conftest.py @@ -22,7 +22,7 @@ from letta.orm import Base from letta.schemas.agent import CreateAgent from letta.schemas.block import Block as PydanticBlock, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig -from 
letta.schemas.enums import JobStatus, MessageRole +from letta.schemas.enums import JobStatus, MessageRole, RunStatus from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.job import BatchJob, Job as PydanticJob @@ -524,13 +524,13 @@ async def default_job(server: SyncServer, default_user): @pytest.fixture -async def default_run(server: SyncServer, default_user): +async def default_run(server: SyncServer, default_user, sarah_agent): """Create and return a default run.""" run_pydantic = PydanticRun( - user_id=default_user.id, - status=JobStatus.pending, + agent_id=sarah_agent.id, + status=RunStatus.created, ) - run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user) + run = await server.run_manager.create_run(pydantic_run=run_pydantic, actor=default_user) yield run diff --git a/tests/managers/test_job_manager.py b/tests/managers/test_job_manager.py index 58502147..58c48df0 100644 --- a/tests/managers/test_job_manager.py +++ b/tests/managers/test_job_manager.py @@ -358,39 +358,6 @@ async def test_list_jobs_filter_by_type(server: SyncServer, default_user, defaul assert jobs[0].id == run.id -@pytest.mark.asyncio -async def test_list_jobs_by_stop_reason(server: SyncServer, sarah_agent, default_user): - """Test listing jobs by stop reason.""" - - run_pydantic = PydanticRun( - user_id=default_user.id, - status=JobStatus.pending, - job_type=JobType.RUN, - stop_reason=StopReasonType.requires_approval, - agent_id=sarah_agent.id, - background=True, - ) - run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user) - assert run.stop_reason == StopReasonType.requires_approval - assert run.background == True - assert run.agent_id == sarah_agent.id - - # list jobs by stop reason - jobs = await server.job_manager.list_jobs_async(actor=default_user, 
job_type=JobType.RUN, stop_reason=StopReasonType.requires_approval) - assert len(jobs) == 1 - assert jobs[0].id == run.id - - # list jobs by background - jobs = await server.job_manager.list_jobs_async(actor=default_user, job_type=JobType.RUN, background=True) - assert len(jobs) == 1 - assert jobs[0].id == run.id - - # list jobs by agent_id - jobs = await server.job_manager.list_jobs_async(actor=default_user, job_type=JobType.RUN, agent_ids=[sarah_agent.id]) - assert len(jobs) == 1 - assert jobs[0].id == run.id - - async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user): """Test that job callbacks are properly dispatched when a job is completed.""" captured = {} @@ -445,524 +412,7 @@ async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user): # ====================================================================================================================== -@pytest.mark.asyncio -async def test_job_messages_add(server: SyncServer, default_run, hello_world_message_fixture, default_user): - """Test adding a message to a job.""" - # Add message to job - await server.job_manager.add_messages_to_job_async( - job_id=default_run.id, - message_ids=[hello_world_message_fixture.id], - actor=default_user, - ) - # Verify message was added - messages = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - ) - assert len(messages) == 1 - assert messages[0].id == hello_world_message_fixture.id - assert messages[0].content[0].text == hello_world_message_fixture.content[0].text - - -@pytest.mark.asyncio -async def test_job_messages_pagination(server: SyncServer, default_run, default_user, sarah_agent): - """Test pagination of job messages.""" - # Create multiple messages - message_ids = [] - for i in range(5): - message = PydanticMessage( - agent_id=sarah_agent.id, - role=MessageRole.user, - content=[TextContent(text=f"Test message {i}")], - ) - msg = await 
server.message_manager.create_many_messages_async([message], actor=default_user) - message_ids.append(msg[0].id) - - # Add message to job - await server.job_manager.add_messages_to_job_async( - job_id=default_run.id, - message_ids=[msg[0].id], - actor=default_user, - ) - - # Test pagination with limit - messages = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - limit=2, - ) - assert len(messages) == 2 - assert messages[0].id == message_ids[0] - assert messages[1].id == message_ids[1] - - # Test pagination with cursor - first_page = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - limit=2, - ascending=True, # [M0, M1] - ) - assert len(first_page) == 2 - assert first_page[0].id == message_ids[0] - assert first_page[1].id == message_ids[1] - assert first_page[0].created_at <= first_page[1].created_at - - last_page = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - limit=2, - ascending=False, # [M4, M3] - ) - assert len(last_page) == 2 - assert last_page[0].id == message_ids[4] - assert last_page[1].id == message_ids[3] - assert last_page[0].created_at >= last_page[1].created_at - - first_page_ids = set(msg.id for msg in first_page) - last_page_ids = set(msg.id for msg in last_page) - assert first_page_ids.isdisjoint(last_page_ids) - - # Test middle page using both before and after - middle_page = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - before=last_page[-1].id, # M3 - after=first_page[0].id, # M0 - ascending=True, # [M1, M2] - ) - assert len(middle_page) == 2 # Should include message between first and last pages - assert middle_page[0].id == message_ids[1] - assert middle_page[1].id == message_ids[2] - head_tail_msgs = first_page_ids.union(last_page_ids) - assert middle_page[1].id not in head_tail_msgs - assert middle_page[0].id in first_page_ids - - # Test descending order for middle page 
- middle_page = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - before=last_page[-1].id, # M3 - after=first_page[0].id, # M0 - ascending=False, # [M2, M1] - ) - assert len(middle_page) == 2 # Should include message between first and last pages - assert middle_page[0].id == message_ids[2] - assert middle_page[1].id == message_ids[1] - - # Test getting earliest messages - msg_3 = last_page[-1].id - earliest_msgs = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - ascending=False, - before=msg_3, # Get messages after M3 in descending order - ) - assert len(earliest_msgs) == 3 # Should get M2, M1, M0 - assert all(m.id not in last_page_ids for m in earliest_msgs) - assert earliest_msgs[0].created_at > earliest_msgs[1].created_at > earliest_msgs[2].created_at - - # Test getting earliest messages with ascending order - earliest_msgs_ascending = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - ascending=True, - before=msg_3, # Get messages before M3 in ascending order - ) - assert len(earliest_msgs_ascending) == 3 # Should get M0, M1, M2 - assert all(m.id not in last_page_ids for m in earliest_msgs_ascending) - assert earliest_msgs_ascending[0].created_at < earliest_msgs_ascending[1].created_at < earliest_msgs_ascending[2].created_at - - -@pytest.mark.asyncio -async def test_job_messages_ordering(server: SyncServer, default_run, default_user, sarah_agent): - """Test that messages are ordered by created_at.""" - # Create messages with different timestamps - base_time = datetime.now(timezone.utc) - message_times = [ - base_time - timedelta(minutes=2), - base_time - timedelta(minutes=1), - base_time, - ] - - for i, created_at in enumerate(message_times): - message = PydanticMessage( - role=MessageRole.user, - content=[TextContent(text="Test message")], - agent_id=sarah_agent.id, - created_at=created_at, - ) - msg = await 
server.message_manager.create_many_messages_async([message], actor=default_user) - - # Add message to job - await server.job_manager.add_messages_to_job_async( - job_id=default_run.id, - message_ids=[msg[0].id], - actor=default_user, - ) - - # Verify messages are returned in chronological order - returned_messages = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - ) - - assert len(returned_messages) == 3 - assert returned_messages[0].created_at < returned_messages[1].created_at - assert returned_messages[1].created_at < returned_messages[2].created_at - - # Verify messages are returned in descending order - returned_messages = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - ascending=False, - ) - - assert len(returned_messages) == 3 - assert returned_messages[0].created_at > returned_messages[1].created_at - assert returned_messages[1].created_at > returned_messages[2].created_at - - -@pytest.mark.asyncio -async def test_job_messages_empty(server: SyncServer, default_run, default_user): - """Test getting messages for a job with no messages.""" - messages = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - ) - assert len(messages) == 0 - - -@pytest.mark.asyncio -async def test_job_messages_add_duplicate(server: SyncServer, default_run, hello_world_message_fixture, default_user): - """Test adding the same message to a job twice.""" - # Add message to job first time - await server.job_manager.add_messages_to_job_async( - job_id=default_run.id, - message_ids=[hello_world_message_fixture.id], - actor=default_user, - ) - - # Attempt to add same message again - with pytest.raises(IntegrityError): - await server.job_manager.add_messages_to_job_async( - job_id=default_run.id, - message_ids=[hello_world_message_fixture.id], - actor=default_user, - ) - - -@pytest.mark.asyncio -async def test_job_messages_filter(server: SyncServer, default_run, 
default_user, sarah_agent): - """Test getting messages associated with a job.""" - # Create test messages with different roles and tool calls - messages = [ - PydanticMessage( - role=MessageRole.user, - content=[TextContent(text="Hello")], - agent_id=sarah_agent.id, - ), - PydanticMessage( - role=MessageRole.assistant, - content=[TextContent(text="Hi there!")], - agent_id=sarah_agent.id, - ), - PydanticMessage( - role=MessageRole.assistant, - content=[TextContent(text="Let me help you with that")], - agent_id=sarah_agent.id, - tool_calls=[ - OpenAIToolCall( - id="call_1", - type="function", - function=OpenAIFunction( - name="test_tool", - arguments='{"arg1": "value1"}', - ), - ) - ], - ), - ] - - # Add messages to job - for msg in messages: - created_msg = await server.message_manager.create_many_messages_async([msg], actor=default_user) - await server.job_manager.add_messages_to_job_async( - job_id=default_run.id, - message_ids=[created_msg[0].id], - actor=default_user, - ) - - # Test getting all messages - all_messages = await server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - ) - assert len(all_messages) == 3 - - # Test filtering by role - user_messages = await server.job_manager.get_job_messages(job_id=default_run.id, actor=default_user, role=MessageRole.user) - assert len(user_messages) == 1 - assert user_messages[0].role == MessageRole.user - - # Test limit - limited_messages = await server.job_manager.get_job_messages(job_id=default_run.id, actor=default_user, limit=2) - assert len(limited_messages) == 2 - - -@pytest.mark.asyncio -async def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_agent): - """Test getting messages for a run with request config.""" - # Create a run with custom request config - run = await server.job_manager.create_job_async( - pydantic_job=PydanticRun( - user_id=default_user.id, - status=JobStatus.created, - request_config=LettaRequestConfig( - use_assistant_message=False, 
assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg" - ), - ), - actor=default_user, - ) - - # Add some messages - messages = [ - PydanticMessage( - agent_id=sarah_agent.id, - role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, - content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')], - tool_calls=( - [{"type": "function", "id": f"call_{i // 2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] - if i % 2 == 1 - else None - ), - tool_call_id=f"call_{i // 2}" if i % 2 == 0 else None, - ) - for i in range(4) - ] - - created_msg = await server.message_manager.create_many_messages_async(messages, actor=default_user) - for msg in created_msg: - await server.job_manager.add_messages_to_job_async( - job_id=run.id, - message_ids=[msg.id], - actor=default_user, - ) - - # Get messages and verify they're converted correctly - result = await server.job_manager.get_run_messages(run_id=run.id, actor=default_user) - - # Verify correct number of messages. 
Assistant messages should be parsed - assert len(result) == 6 - - # Verify assistant messages are parsed according to request config - tool_call_messages = [msg for msg in result if msg.message_type == "tool_call_message"] - reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"] - assert len(tool_call_messages) == 2 - assert len(reasoning_messages) == 2 - for msg in tool_call_messages: - assert msg.tool_call is not None - assert msg.tool_call.name == "custom_tool" - - -@pytest.mark.asyncio -async def test_get_run_messages_with_assistant_message(server: SyncServer, default_user: PydanticUser, sarah_agent): - """Test getting messages for a run with request config.""" - # Create a run with custom request config - run = await server.job_manager.create_job_async( - pydantic_job=PydanticRun( - user_id=default_user.id, - status=JobStatus.created, - request_config=LettaRequestConfig( - use_assistant_message=True, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg" - ), - ), - actor=default_user, - ) - - # Add some messages - messages = [ - PydanticMessage( - agent_id=sarah_agent.id, - role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, - content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')], - tool_calls=( - [{"type": "function", "id": f"call_{i // 2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] - if i % 2 == 1 - else None - ), - tool_call_id=f"call_{i // 2}" if i % 2 == 0 else None, - ) - for i in range(4) - ] - - created_msg = await server.message_manager.create_many_messages_async(messages, actor=default_user) - for msg in created_msg: - await server.job_manager.add_messages_to_job_async( - job_id=run.id, - message_ids=[msg.id], - actor=default_user, - ) - - # Get messages and verify they're converted correctly - result = await server.job_manager.get_run_messages(run_id=run.id, actor=default_user) - - # Verify correct number of 
messages. Assistant messages should be parsed - assert len(result) == 4 - - # Verify assistant messages are parsed according to request config - assistant_messages = [msg for msg in result if msg.message_type == "assistant_message"] - reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"] - assert len(assistant_messages) == 2 - assert len(reasoning_messages) == 2 - for msg in assistant_messages: - assert msg.content == "test" - for msg in reasoning_messages: - assert "Test message" in msg.reasoning - - -# ====================================================================================================================== -# JobManager Tests - Usage Statistics - -# ====================================================================================================================== - -# TODO: add these back after runs refactor - -# @pytest.mark.asyncio -# async def test_job_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_job, default_user): -# """Test adding and retrieving job usage statistics.""" -# job_manager = server.job_manager -# step_manager = server.step_manager -# -# # Add usage statistics -# await step_manager.log_step_async( -# agent_id=sarah_agent.id, -# provider_name="openai", -# provider_category="base", -# model="gpt-4o-mini", -# model_endpoint="https://api.openai.com/v1", -# context_window_limit=8192, -# job_id=default_job.id, -# usage=UsageStatistics( -# completion_tokens=100, -# prompt_tokens=50, -# total_tokens=150, -# ), -# actor=default_user, -# project_id=sarah_agent.project_id, -# ) -# -# # Get usage statistics -# usage_stats = await job_manager.get_job_usage_async(job_id=default_job.id, actor=default_user) -# -# # Verify the statistics -# assert usage_stats.completion_tokens == 100 -# assert usage_stats.prompt_tokens == 50 -# assert usage_stats.total_tokens == 150 -# -# # get steps -# steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user) -# assert len(steps) == 1 -# -# 
-# @pytest.mark.asyncio -# async def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_user): -# """Test getting usage statistics for a job with no stats.""" -# job_manager = server.job_manager -# -# # Get usage statistics for a job with no stats -# usage_stats = await job_manager.get_job_usage(job_id=default_job.id, actor=default_user) -# -# # Verify default values -# assert usage_stats.completion_tokens == 0 -# assert usage_stats.prompt_tokens == 0 -# assert usage_stats.total_tokens == 0 -# -# # get steps -# steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user) -# assert len(steps) == 0 -# -# -# @pytest.mark.asyncio -# async def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_job, default_user): -# """Test adding multiple usage statistics entries for a job.""" -# job_manager = server.job_manager -# step_manager = server.step_manager -# -# # Add first usage statistics entry -# await step_manager.log_step_async( -# agent_id=sarah_agent.id, -# provider_name="openai", -# provider_category="base", -# model="gpt-4o-mini", -# model_endpoint="https://api.openai.com/v1", -# context_window_limit=8192, -# job_id=default_job.id, -# usage=UsageStatistics( -# completion_tokens=100, -# prompt_tokens=50, -# total_tokens=150, -# ), -# actor=default_user, -# project_id=sarah_agent.project_id, -# ) -# -# # Add second usage statistics entry -# await step_manager.log_step_async( -# agent_id=sarah_agent.id, -# provider_name="openai", -# provider_category="base", -# model="gpt-4o-mini", -# model_endpoint="https://api.openai.com/v1", -# context_window_limit=8192, -# job_id=default_job.id, -# usage=UsageStatistics( -# completion_tokens=200, -# prompt_tokens=100, -# total_tokens=300, -# ), -# actor=default_user, -# project_id=sarah_agent.project_id, -# ) -# -# # Get usage statistics (should return the latest entry) -# usage_stats = job_manager.get_job_usage(job_id=default_job.id, actor=default_user) -# -# # Verify 
we get the most recent statistics -# assert usage_stats.completion_tokens == 300 -# assert usage_stats.prompt_tokens == 150 -# assert usage_stats.total_tokens == 450 -# assert usage_stats.step_count == 2 -# -# # get steps -# steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user) -# assert len(steps) == 2 -# -# # get agent steps -# steps = await step_manager.list_steps_async(agent_id=sarah_agent.id, actor=default_user) -# assert len(steps) == 2 -# -# # add step feedback -# step_manager = server.step_manager -# -# # Add feedback to first step -# await step_manager.add_feedback_async(step_id=steps[0].id, feedback=FeedbackType.POSITIVE, actor=default_user) -# -# # Test has_feedback filtering -# steps_with_feedback = await step_manager.list_steps_async(agent_id=sarah_agent.id, has_feedback=True, actor=default_user) -# assert len(steps_with_feedback) == 1 -# -# steps_without_feedback = await step_manager.list_steps_async(agent_id=sarah_agent.id, actor=default_user) -# assert len(steps_without_feedback) == 2 -# - - -# @pytest.mark.asyncio -# async def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): -# """Test getting usage statistics for a nonexistent job.""" -# job_manager = server.job_manager -# -# with pytest.raises(NoResultFound): -# job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user) @pytest.mark.asyncio diff --git a/tests/managers/test_message_manager.py b/tests/managers/test_message_manager.py index cd1a6e02..8a4c12fd 100644 --- a/tests/managers/test_message_manager.py +++ b/tests/managers/test_message_manager.py @@ -296,7 +296,7 @@ async def test_modify_letta_message(server: SyncServer, sarah_agent, default_use Test updating a message. 
""" - messages = await server.message_manager.list_messages_for_agent_async(agent_id=sarah_agent.id, actor=default_user) + messages = await server.message_manager.list_messages(agent_id=sarah_agent.id, actor=default_user) letta_messages = PydanticMessage.to_letta_messages_from_list(messages=messages) system_message = [msg for msg in letta_messages if msg.message_type == "system_message"][0] diff --git a/tests/managers/test_run_manager.py b/tests/managers/test_run_manager.py new file mode 100644 index 00000000..324f49fd --- /dev/null +++ b/tests/managers/test_run_manager.py @@ -0,0 +1,1104 @@ +import json +import logging +import os +import random +import re +import string +import time +import uuid +from datetime import datetime, timedelta, timezone +from typing import List +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from _pytest.python_api import approx +from anthropic.types.beta import BetaMessage +from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult + +# Import shared fixtures and constants from conftest +from conftest import ( + CREATE_DELAY_SQLITE, + DEFAULT_EMBEDDING_CONFIG, + USING_SQLITE, +) +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction +from sqlalchemy import func, select +from sqlalchemy.exc import IntegrityError, InvalidRequestError +from sqlalchemy.orm.exc import StaleDataError + +from letta.config import LettaConfig +from letta.constants import ( + BASE_MEMORY_TOOLS, + BASE_SLEEPTIME_TOOLS, + BASE_TOOLS, + BASE_VOICE_SLEEPTIME_CHAT_TOOLS, + BASE_VOICE_SLEEPTIME_TOOLS, + BUILTIN_TOOLS, + DEFAULT_ORG_ID, + DEFAULT_ORG_NAME, + FILES_TOOLS, + LETTA_TOOL_EXECUTION_DIR, + LETTA_TOOL_SET, + LOCAL_ONLY_MULTI_AGENT_TOOLS, + MCP_TOOL_TAG_NAME_PREFIX, + MULTI_AGENT_TOOLS, +) +from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client +from letta.errors import 
LettaAgentNotFoundError +from letta.functions.functions import derive_openai_json_schema, parse_source_code +from letta.functions.mcp_client.types import MCPTool +from letta.helpers import ToolRulesSolver +from letta.helpers.datetime_helpers import AsyncTimer +from letta.jobs.types import ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo +from letta.orm import Base, Block +from letta.orm.block_history import BlockHistory +from letta.orm.errors import NoResultFound, UniqueConstraintViolationError +from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel +from letta.schemas.agent import CreateAgent, UpdateAgent +from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ( + ActorType, + AgentStepStatus, + FileProcessingStatus, + JobStatus, + JobType, + MessageRole, + ProviderType, + RunStatus, + SandboxType, + StepStatus, + TagMatchMode, + ToolType, + VectorDBProvider, +) +from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate +from letta.schemas.file import FileMetadata, FileMetadata as PydanticFileMetadata +from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert +from letta.schemas.job import Job as PydanticJob, LettaRequestConfig +from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage +from letta.schemas.letta_message_content import TextContent +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message, Message as PydanticMessage, MessageCreate, MessageUpdate +from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.organization import Organization, 
Organization as PydanticOrganization, OrganizationUpdate +from letta.schemas.passage import Passage as PydanticPassage +from letta.schemas.pip_requirement import PipRequirement +from letta.schemas.run import Run as PydanticRun, RunUpdate +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate +from letta.schemas.source import Source as PydanticSource, SourceUpdate +from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate +from letta.schemas.tool_rule import InitToolRule +from letta.schemas.user import User as PydanticUser, UserUpdate +from letta.server.db import db_registry +from letta.server.server import SyncServer +from letta.services.block_manager import BlockManager +from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools, validate_agent_exists_async +from letta.services.step_manager import FeedbackType +from letta.settings import settings, tool_settings +from letta.utils import calculate_file_defaults_based_on_context_window +from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview +from tests.utils import random_string + +# ====================================================================================================================== +# RunManager Tests +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_create_run(server: SyncServer, sarah_agent, default_user): + """Test creating a run.""" + run_data = PydanticRun( + metadata={"type": "test"}, + agent_id=sarah_agent.id, + ) + + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Assertions to ensure the created run matches the expected values + assert created_run.agent_id == sarah_agent.id + assert created_run.created_at + assert created_run.status == RunStatus.created + assert 
created_run.metadata == {"type": "test"} + + +@pytest.mark.asyncio +async def test_get_run_by_id(server: SyncServer, sarah_agent, default_user): + """Test fetching a run by ID.""" + # Create a run + run_data = PydanticRun( + metadata={"type": "test"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Fetch the run by ID + fetched_run = await server.run_manager.get_run_by_id(created_run.id, actor=default_user) + + # Assertions to ensure the fetched run matches the created run + assert fetched_run.id == created_run.id + assert fetched_run.status == RunStatus.created + assert fetched_run.metadata == {"type": "test"} + + +@pytest.mark.asyncio +async def test_list_runs(server: SyncServer, sarah_agent, default_user): + """Test listing runs.""" + # Create multiple runs + for i in range(3): + run_data = PydanticRun( + metadata={"type": f"test-{i}"}, + agent_id=sarah_agent.id, + ) + await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # List runs + runs = await server.run_manager.list_runs(actor=default_user) + + # Assertions to check that the created runs are listed + assert len(runs) == 3 + assert all(run.agent_id == sarah_agent.id for run in runs) + assert all(run.metadata["type"].startswith("test") for run in runs) + + +@pytest.mark.asyncio +async def test_list_runs_with_metadata(server: SyncServer, sarah_agent, default_user): + for i in range(3): + run_data = PydanticRun(agent_id=sarah_agent.id) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + if i == 1: + await server.run_manager.update_run_by_id_async(created_run.id, RunUpdate(status=RunStatus.completed), actor=default_user) + + runs = await server.run_manager.list_runs(actor=default_user, statuses=[RunStatus.completed]) + assert len(runs) == 1 + assert runs[0].status == RunStatus.completed + + runs = await server.run_manager.list_runs(actor=default_user) + 
assert len(runs) == 3 + + +@pytest.mark.asyncio +async def test_update_run_by_id(server: SyncServer, sarah_agent, default_user): + """Test updating a run by its ID.""" + # Create a run + run_data = PydanticRun( + metadata={"type": "test"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Update the run + updated_run = await server.run_manager.update_run_by_id_async(created_run.id, RunUpdate(status=RunStatus.completed), actor=default_user) + + # Assertions to ensure the run was updated + assert updated_run.status == RunStatus.completed + + +@pytest.mark.asyncio +async def test_delete_run_by_id(server: SyncServer, sarah_agent, default_user): + """Test deleting a run by its ID.""" + # Create a run + run_data = PydanticRun( + metadata={"type": "test"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + print("created_run to delete", created_run.id) + + # Delete the run + await server.run_manager.delete_run(created_run.id, actor=default_user) + + # Fetch the run by ID + with pytest.raises(NoResultFound): + await server.run_manager.get_run_by_id(created_run.id, actor=default_user) + + # List runs to ensure the run was deleted + runs = await server.run_manager.list_runs(actor=default_user) + assert len(runs) == 0 + + +@pytest.mark.asyncio +async def test_update_run_auto_complete(server: SyncServer, default_user, sarah_agent): + """Test that updating a run's status to 'completed' automatically sets completed_at.""" + # Create a run + run_data = PydanticRun( + metadata={"type": "test"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + assert created_run.completed_at is None + + # Update the run to completed status + updated_run = await server.run_manager.update_run_by_id_async(created_run.id, RunUpdate(status=RunStatus.completed), 
actor=default_user) + + # Check that completed_at was automatically set + assert updated_run.completed_at is not None + assert isinstance(updated_run.completed_at, datetime) + + +@pytest.mark.asyncio +async def test_get_run_not_found(server: SyncServer, default_user): + """Test fetching a non-existent run.""" + non_existent_run_id = "nonexistent-id" + with pytest.raises(NoResultFound): + await server.run_manager.get_run_by_id(non_existent_run_id, actor=default_user) + + +@pytest.mark.asyncio +async def test_delete_run_not_found(server: SyncServer, default_user): + """Test deleting a non-existent run.""" + non_existent_run_id = "nonexistent-id" + with pytest.raises(NoResultFound): + await server.run_manager.delete_run(non_existent_run_id, actor=default_user) + + +@pytest.mark.asyncio +async def test_list_runs_pagination(server: SyncServer, sarah_agent, default_user): + """Test listing runs with pagination.""" + # Create multiple runs + for i in range(10): + run_data = PydanticRun(agent_id=sarah_agent.id) + await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # List runs with a limit + runs = await server.run_manager.list_runs(actor=default_user, limit=5) + assert len(runs) == 5 + assert all(run.agent_id == sarah_agent.id for run in runs) + + # Test cursor-based pagination + first_page = await server.run_manager.list_runs(actor=default_user, limit=3, ascending=True) # [J0, J1, J2] + assert len(first_page) == 3 + assert first_page[0].created_at <= first_page[1].created_at <= first_page[2].created_at + + last_page = await server.run_manager.list_runs(actor=default_user, limit=3, ascending=False) # [J9, J8, J7] + assert len(last_page) == 3 + assert last_page[0].created_at >= last_page[1].created_at >= last_page[2].created_at + first_page_ids = set(run.id for run in first_page) + last_page_ids = set(run.id for run in last_page) + assert first_page_ids.isdisjoint(last_page_ids) + + # Test middle page using both before and after + middle_page 
= await server.run_manager.list_runs( + actor=default_user, before=last_page[-1].id, after=first_page[-1].id, ascending=True + ) # [J3, J4, J5, J6] + assert len(middle_page) == 4 # Should include jobs between first and second page + head_tail_jobs = first_page_ids.union(last_page_ids) + assert all(job.id not in head_tail_jobs for job in middle_page) + + # NOTE: made some changes about assumptions for ascending + + # Test descending order + middle_page_desc = await server.run_manager.list_runs( + # actor=default_user, before=last_page[-1].id, after=first_page[-1].id, ascending=False + actor=default_user, + before=first_page[-1].id, + after=last_page[-1].id, + ascending=False, + ) # [J6, J5, J4, J3] + assert len(middle_page_desc) == 4 + assert middle_page_desc[0].id == middle_page[-1].id + assert middle_page_desc[1].id == middle_page[-2].id + assert middle_page_desc[2].id == middle_page[-3].id + assert middle_page_desc[3].id == middle_page[-4].id + + # BONUS + run_7 = last_page[-1].id + # earliest_runs = await server.run_manager.list_runs(actor=default_user, ascending=False, before=run_7) + earliest_runs = await server.run_manager.list_runs(actor=default_user, ascending=True, before=run_7) + assert len(earliest_runs) == 7 + assert all(j.id not in last_page_ids for j in earliest_runs) + # assert all(earliest_runs[i].created_at >= earliest_runs[i + 1].created_at for i in range(len(earliest_runs) - 1)) + assert all(earliest_runs[i].created_at <= earliest_runs[i + 1].created_at for i in range(len(earliest_runs) - 1)) + + +@pytest.mark.asyncio +async def test_list_runs_by_status(server: SyncServer, default_user, sarah_agent): + """Test listing runs filtered by status.""" + # Create multiple runs with different statuses + run_data_created = PydanticRun( + status=RunStatus.created, + metadata={"type": "test-created"}, + agent_id=sarah_agent.id, + ) + run_data_in_progress = PydanticRun( + status=RunStatus.running, + metadata={"type": "test-running"}, +
agent_id=sarah_agent.id, + ) + run_data_completed = PydanticRun( + status=RunStatus.completed, + metadata={"type": "test-completed"}, + agent_id=sarah_agent.id, + ) + + await server.run_manager.create_run(pydantic_run=run_data_created, actor=default_user) + await server.run_manager.create_run(pydantic_run=run_data_in_progress, actor=default_user) + await server.run_manager.create_run(pydantic_run=run_data_completed, actor=default_user) + + # List runs filtered by status + created_runs = await server.run_manager.list_runs(actor=default_user, statuses=[RunStatus.created]) + in_progress_runs = await server.run_manager.list_runs(actor=default_user, statuses=[RunStatus.running]) + completed_runs = await server.run_manager.list_runs(actor=default_user, statuses=[RunStatus.completed]) + + # Assertions + assert len(created_runs) == 1 + assert created_runs[0].metadata["type"] == run_data_created.metadata["type"] + + assert len(in_progress_runs) == 1 + assert in_progress_runs[0].metadata["type"] == run_data_in_progress.metadata["type"] + + assert len(completed_runs) == 1 + assert completed_runs[0].metadata["type"] == run_data_completed.metadata["type"] + + +@pytest.mark.asyncio +async def test_list_runs_by_stop_reason(server: SyncServer, sarah_agent, default_user): + """Test listing runs by stop reason.""" + + run_pydantic = PydanticRun( + agent_id=sarah_agent.id, + stop_reason=StopReasonType.requires_approval, + background=True, + ) + run = await server.run_manager.create_run(pydantic_run=run_pydantic, actor=default_user) + assert run.stop_reason == StopReasonType.requires_approval + assert run.background == True + assert run.agent_id == sarah_agent.id + + # list runs by stop reason + runs = await server.run_manager.list_runs(actor=default_user, stop_reason=StopReasonType.requires_approval) + assert len(runs) == 1 + assert runs[0].id == run.id + + # list runs by background + runs = await server.run_manager.list_runs(actor=default_user, background=True) + assert len(runs) == 
1 + assert runs[0].id == run.id + + # list runs by agent_id + runs = await server.run_manager.list_runs(actor=default_user, agent_ids=[sarah_agent.id]) + assert len(runs) == 1 + assert runs[0].id == run.id + + +async def test_e2e_run_callback(monkeypatch, server: SyncServer, default_user, sarah_agent): + """Test that run callbacks are properly dispatched when a run is completed.""" + captured = {} + + # Create a simple mock for the async HTTP client + class MockAsyncResponse: + status_code = 202 + + async def mock_post(url, json, timeout): + captured["url"] = url + captured["json"] = json + return MockAsyncResponse() + + class MockAsyncClient: + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + async def post(self, url, json, timeout): + return await mock_post(url, json, timeout) + + # Patch the AsyncClient + import letta.services.run_manager as run_manager_module + + monkeypatch.setattr(run_manager_module, "AsyncClient", MockAsyncClient) + + run_in = PydanticRun( + status=RunStatus.created, metadata={"foo": "bar"}, agent_id=sarah_agent.id, callback_url="http://example.test/webhook/runs" + ) + created = await server.run_manager.create_run(pydantic_run=run_in, actor=default_user) + assert created.callback_url == "http://example.test/webhook/runs" + + # Update the run status to completed, which should trigger the callback + updated = await server.run_manager.update_run_by_id_async( + created.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user + ) + + # Verify the callback was triggered with the correct parameters + assert captured["url"] == created.callback_url, "Callback URL doesn't match" + assert captured["json"]["run_id"] == created.id, "Run ID in callback doesn't match" + assert captured["json"]["status"] == RunStatus.completed.value, "Run status in callback doesn't match" + + # Verify the completed_at timestamp is reasonable + actual_dt = 
datetime.fromisoformat(captured["json"]["completed_at"]).replace(tzinfo=None) + # Remove timezone from updated.completed_at for comparison (it comes from DB as timezone-aware) + assert abs((actual_dt - updated.completed_at).total_seconds()) < 1, "Timestamp difference is too large" + + assert isinstance(updated.callback_sent_at, datetime) + assert updated.callback_status_code == 202 + + +@pytest.mark.asyncio +async def test_run_callback_only_on_terminal_status(server: SyncServer, sarah_agent, default_user, monkeypatch): + """ + Regression: ensure a non-terminal update (running) does NOT set completed_at or trigger callback, + and that a subsequent terminal update (completed) does trigger the callback exactly once. + """ + + # Capture callback invocations + captured = {"count": 0, "url": None, "json": None} + + class MockAsyncResponse: + status_code = 202 + + async def mock_post(url, json, timeout): + captured["count"] += 1 + captured["url"] = url + captured["json"] = json + return MockAsyncResponse() + + class MockAsyncClient: + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + async def post(self, url, json, timeout): + return await mock_post(url, json, timeout) + + # Patch the AsyncClient in run_manager module + import letta.services.run_manager as run_manager_module + + monkeypatch.setattr(run_manager_module, "AsyncClient", MockAsyncClient) + + # Create run with a callback URL + run_in = PydanticRun( + status=RunStatus.created, + metadata={"foo": "bar"}, + agent_id=sarah_agent.id, + callback_url="http://example.test/webhook/runs", + ) + created = await server.run_manager.create_run(pydantic_run=run_in, actor=default_user) + assert created.callback_url == "http://example.test/webhook/runs" + + # 1) Non-terminal update: running + updated_running = await server.run_manager.update_run_by_id_async(created.id, RunUpdate(status=RunStatus.running), actor=default_user) + + # Should not set completed_at or trigger callback + assert 
updated_running.completed_at is None + assert captured["count"] == 0 + + # 2) Terminal update: completed + updated_completed = await server.run_manager.update_run_by_id_async( + created.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user + ) + + # Should trigger exactly one callback with expected payload + assert captured["count"] == 1 + assert captured["url"] == created.callback_url + assert captured["json"]["run_id"] == created.id + assert captured["json"]["status"] == RunStatus.completed.value + + # completed_at should be set and align closely with callback payload + assert updated_completed.completed_at is not None + actual_dt = datetime.fromisoformat(captured["json"]["completed_at"]).replace(tzinfo=None) + assert abs((actual_dt - updated_completed.completed_at).total_seconds()) < 1 + + assert isinstance(updated_completed.callback_sent_at, datetime) + assert updated_completed.callback_status_code == 202 + + +# ====================================================================================================================== +# RunManager Tests - Messages +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_run_messages_pagination(server: SyncServer, default_run, default_user, sarah_agent): + """Test pagination of run messages.""" + + # create the run + run_pydantic = PydanticRun( + agent_id=sarah_agent.id, + status=RunStatus.created, + metadata={"foo": "bar"}, + ) + run = await server.run_manager.create_run(pydantic_run=run_pydantic, actor=default_user) + assert run.status == RunStatus.created + + # Create multiple messages + message_ids = [] + for i in range(5): + message = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"Test message {i}")], + run_id=run.id, + ) + msg = await server.message_manager.create_many_messages_async([message], 
actor=default_user) + message_ids.append(msg[0].id) + + # Test pagination with limit + messages = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + limit=2, + ) + assert len(messages) == 2 + assert messages[0].id == message_ids[0] + assert messages[1].id == message_ids[1] + + # Test pagination with cursor + first_page = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + limit=2, + ascending=True, # [M0, M1] + ) + assert len(first_page) == 2 + assert first_page[0].id == message_ids[0] + assert first_page[1].id == message_ids[1] + assert first_page[0].created_at <= first_page[1].created_at + + last_page = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + limit=2, + ascending=False, # [M4, M3] + ) + assert len(last_page) == 2 + assert last_page[0].id == message_ids[4] + assert last_page[1].id == message_ids[3] + assert last_page[0].created_at >= last_page[1].created_at + + first_page_ids = set(msg.id for msg in first_page) + last_page_ids = set(msg.id for msg in last_page) + assert first_page_ids.isdisjoint(last_page_ids) + + # Test middle page using both before and after + middle_page = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + before=last_page[-1].id, # M3 + after=first_page[0].id, # M0 + ascending=True, # [M1, M2] + ) + assert len(middle_page) == 2 # Should include message between first and last pages + assert middle_page[0].id == message_ids[1] + assert middle_page[1].id == message_ids[2] + head_tail_msgs = first_page_ids.union(last_page_ids) + assert middle_page[1].id not in head_tail_msgs + assert middle_page[0].id in first_page_ids + + # Test descending order for middle page + middle_page = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + before=last_page[-1].id, # M3 + after=first_page[0].id, # M0 + ascending=False, # [M2, M1] + ) + assert len(middle_page) == 2 # Should include 
message between first and last pages + assert middle_page[0].id == message_ids[2] + assert middle_page[1].id == message_ids[1] + + # Test getting earliest messages + msg_3 = last_page[-1].id + earliest_msgs = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + ascending=False, + before=msg_3, # Get messages after M3 in descending order + ) + assert len(earliest_msgs) == 3 # Should get M2, M1, M0 + assert all(m.id not in last_page_ids for m in earliest_msgs) + assert earliest_msgs[0].created_at > earliest_msgs[1].created_at > earliest_msgs[2].created_at + + # Test getting earliest messages with ascending order + earliest_msgs_ascending = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + ascending=True, + before=msg_3, # Get messages before M3 in ascending order + ) + assert len(earliest_msgs_ascending) == 3 # Should get M0, M1, M2 + assert all(m.id not in last_page_ids for m in earliest_msgs_ascending) + assert earliest_msgs_ascending[0].created_at < earliest_msgs_ascending[1].created_at < earliest_msgs_ascending[2].created_at + + +@pytest.mark.asyncio +async def test_run_messages_ordering(server: SyncServer, default_run, default_user, sarah_agent): + """Test that messages are ordered by created_at.""" + # Create messages with different timestamps + base_time = datetime.now(timezone.utc) + message_times = [ + base_time - timedelta(minutes=2), + base_time - timedelta(minutes=1), + base_time, + ] + + # create the run + run_pydantic = PydanticRun( + agent_id=sarah_agent.id, + ) + run = await server.run_manager.create_run(pydantic_run=run_pydantic, actor=default_user) + assert run.status == RunStatus.created + + for i, created_at in enumerate(message_times): + message = PydanticMessage( + role=MessageRole.user, + content=[TextContent(text="Test message")], + agent_id=sarah_agent.id, + created_at=created_at, + run_id=run.id, + ) + msg = await server.message_manager.create_many_messages_async([message], 
actor=default_user) + + # Verify messages are returned in chronological order + returned_messages = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + ) + + assert len(returned_messages) == 3 + assert returned_messages[0].created_at < returned_messages[1].created_at + assert returned_messages[1].created_at < returned_messages[2].created_at + + # Verify messages are returned in descending order + returned_messages = await server.message_manager.list_messages( + run_id=run.id, + actor=default_user, + ascending=False, + ) + + assert len(returned_messages) == 3 + assert returned_messages[0].created_at > returned_messages[1].created_at + assert returned_messages[1].created_at > returned_messages[2].created_at + + +@pytest.mark.asyncio +async def test_job_messages_empty(server: SyncServer, default_run, default_user): + """Test getting messages for a job with no messages.""" + messages = await server.message_manager.list_messages( + run_id=default_run.id, + actor=default_user, + ) + assert len(messages) == 0 + + +@pytest.mark.asyncio +async def test_job_messages_filter(server: SyncServer, default_run, default_user, sarah_agent): + """Test getting messages associated with a job.""" + # Create the run + run_pydantic = PydanticRun( + agent_id=sarah_agent.id, + ) + run = await server.run_manager.create_run(pydantic_run=run_pydantic, actor=default_user) + assert run.status == RunStatus.created + + # Create test messages with different roles and tool calls + messages = [ + PydanticMessage( + role=MessageRole.user, + content=[TextContent(text="Hello")], + agent_id=sarah_agent.id, + run_id=default_run.id, + ), + PydanticMessage( + role=MessageRole.assistant, + content=[TextContent(text="Hi there!")], + agent_id=sarah_agent.id, + run_id=default_run.id, + ), + PydanticMessage( + role=MessageRole.assistant, + content=[TextContent(text="Let me help you with that")], + agent_id=sarah_agent.id, + tool_calls=[ + OpenAIToolCall( + id="call_1", + 
type="function", + function=OpenAIFunction( + name="test_tool", + arguments='{"arg1": "value1"}', + ), + ) + ], + run_id=default_run.id, + ), + ] + await server.message_manager.create_many_messages_async(messages, actor=default_user) + + # Test getting all messages + all_messages = await server.message_manager.list_messages( + run_id=default_run.id, + actor=default_user, + ) + assert len(all_messages) == 3 + + # Test filtering by role + user_messages = await server.message_manager.list_messages(run_id=default_run.id, actor=default_user, roles=[MessageRole.user]) + assert len(user_messages) == 1 + assert user_messages[0].role == MessageRole.user + + # Test limit + limited_messages = await server.message_manager.list_messages(run_id=default_run.id, actor=default_user, limit=2) + assert len(limited_messages) == 2 + + +@pytest.mark.asyncio +async def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_agent): + """Test getting messages for a run with request config.""" + # Create a run with custom request config + run = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=sarah_agent.id, + status=RunStatus.created, + request_config=LettaRequestConfig( + use_assistant_message=False, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg" + ), + ), + actor=default_user, + ) + + # Add some messages + messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, + content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')], + tool_calls=( + [{"type": "function", "id": f"call_{i // 2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] + if i % 2 == 1 + else None + ), + tool_call_id=f"call_{i // 2}" if i % 2 == 0 else None, + run_id=run.id, + ) + for i in range(4) + ] + + created_msg = await server.message_manager.create_many_messages_async(messages, actor=default_user) + + # Get 
messages and verify they're converted correctly + result = await server.message_manager.list_messages(run_id=run.id, actor=default_user) + result = Message.to_letta_messages_from_list(result) + + # Verify correct number of messages. Assistant messages should be parsed + assert len(result) == 6 + + # Verify assistant messages are parsed according to request config + tool_call_messages = [msg for msg in result if msg.message_type == "tool_call_message"] + reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"] + assert len(tool_call_messages) == 2 + assert len(reasoning_messages) == 2 + for msg in tool_call_messages: + assert msg.tool_call is not None + assert msg.tool_call.name == "custom_tool" + + +@pytest.mark.asyncio +async def test_get_run_messages_with_assistant_message(server: SyncServer, default_user: PydanticUser, sarah_agent): + """Test getting messages for a run with request config.""" + # Create a run with custom request config + run = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=sarah_agent.id, + status=RunStatus.created, + request_config=LettaRequestConfig( + use_assistant_message=True, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg" + ), + ), + actor=default_user, + ) + + # Add some messages + messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, + content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')], + tool_calls=( + [{"type": "function", "id": f"call_{i // 2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] + if i % 2 == 1 + else None + ), + tool_call_id=f"call_{i // 2}" if i % 2 == 0 else None, + run_id=run.id, + ) + for i in range(4) + ] + + created_msg = await server.message_manager.create_many_messages_async(messages, actor=default_user) + + # Get messages and verify they're converted correctly + result = await 
server.message_manager.list_messages(run_id=run.id, actor=default_user) + result = Message.to_letta_messages_from_list( + result, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg" + ) + + # Verify correct number of messages. Assistant messages should be parsed + assert len(result) == 4 + + # Verify assistant messages are parsed according to request config + assistant_messages = [msg for msg in result if msg.message_type == "assistant_message"] + reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"] + assert len(assistant_messages) == 2 + assert len(reasoning_messages) == 2 + for msg in assistant_messages: + assert msg.content == "test" + for msg in reasoning_messages: + assert "Test message" in msg.reasoning + + +# ====================================================================================================================== +# RunManager Tests - Usage Statistics - +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_run_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_run, default_user): + """Test adding and retrieving run usage statistics.""" + run_manager = server.run_manager + step_manager = server.step_manager + + # Add usage statistics + await step_manager.log_step_async( + agent_id=sarah_agent.id, + provider_name="openai", + provider_category="base", + model="gpt-4o-mini", + model_endpoint="https://api.openai.com/v1", + context_window_limit=8192, + run_id=default_run.id, + usage=UsageStatistics( + completion_tokens=100, + prompt_tokens=50, + total_tokens=150, + ), + actor=default_user, + project_id=sarah_agent.project_id, + ) + + # Get usage statistics + usage_stats = await run_manager.get_run_usage(run_id=default_run.id, actor=default_user) + + # Verify the statistics + assert usage_stats.completion_tokens == 100 + assert usage_stats.prompt_tokens == 50 + 
assert usage_stats.total_tokens == 150 + + # get steps + steps = await step_manager.list_steps_async(run_id=default_run.id, actor=default_user) + assert len(steps) == 1 + + +@pytest.mark.asyncio +async def test_run_usage_stats_get_no_stats(server: SyncServer, default_run, default_user): + """Test getting usage statistics for a job with no stats.""" + run_manager = server.run_manager + + # Get usage statistics for a job with no stats + usage_stats = await run_manager.get_run_usage(run_id=default_run.id, actor=default_user) + + # Verify default values + assert usage_stats.completion_tokens == 0 + assert usage_stats.prompt_tokens == 0 + assert usage_stats.total_tokens == 0 + + # get steps + steps = await server.step_manager.list_steps_async(run_id=default_run.id, actor=default_user) + assert len(steps) == 0 + + +@pytest.mark.asyncio +async def test_run_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_run, default_user): + """Test adding multiple usage statistics entries for a job.""" + run_manager = server.run_manager + step_manager = server.step_manager + + # Add first usage statistics entry + await step_manager.log_step_async( + agent_id=sarah_agent.id, + provider_name="openai", + provider_category="base", + model="gpt-4o-mini", + model_endpoint="https://api.openai.com/v1", + context_window_limit=8192, + usage=UsageStatistics( + completion_tokens=100, + prompt_tokens=50, + total_tokens=150, + ), + actor=default_user, + project_id=sarah_agent.project_id, + run_id=default_run.id, + ) + + # Add second usage statistics entry + await step_manager.log_step_async( + agent_id=sarah_agent.id, + provider_name="openai", + provider_category="base", + model="gpt-4o-mini", + model_endpoint="https://api.openai.com/v1", + context_window_limit=8192, + usage=UsageStatistics( + completion_tokens=200, + prompt_tokens=100, + total_tokens=300, + ), + actor=default_user, + project_id=sarah_agent.project_id, + run_id=default_run.id, + ) + + # Get usage statistics (should 
return the latest entry) + usage_stats = await run_manager.get_run_usage(run_id=default_run.id, actor=default_user) + + # Verify we get the most recent statistics + assert usage_stats.completion_tokens == 300 + assert usage_stats.prompt_tokens == 150 + assert usage_stats.total_tokens == 450 + assert usage_stats.step_count == 2 + + # get steps + steps = await step_manager.list_steps_async(run_id=default_run.id, actor=default_user) + assert len(steps) == 2 + + # get agent steps + steps = await step_manager.list_steps_async(agent_id=sarah_agent.id, actor=default_user) + assert len(steps) == 2 + + # add step feedback + step_manager = server.step_manager + + # Add feedback to first step + await step_manager.add_feedback_async(step_id=steps[0].id, feedback=FeedbackType.POSITIVE, actor=default_user) + + # Test has_feedback filtering + steps_with_feedback = await step_manager.list_steps_async(agent_id=sarah_agent.id, has_feedback=True, actor=default_user) + assert len(steps_with_feedback) == 1 + + steps_without_feedback = await step_manager.list_steps_async(agent_id=sarah_agent.id, actor=default_user) + assert len(steps_without_feedback) == 2 + + +@pytest.mark.asyncio +async def test_run_usage_stats_get_nonexistent_run(server: SyncServer, default_user): + """Test getting usage statistics for a nonexistent run.""" + run_manager = server.run_manager + + with pytest.raises(NoResultFound): + await run_manager.get_run_usage(run_id="nonexistent_run", actor=default_user) + + +@pytest.mark.asyncio +async def test_get_run_request_config(server: SyncServer, sarah_agent, default_user): + """Test getting request config from a run.""" + request_config = LettaRequestConfig( + use_assistant_message=True, assistant_message_tool_name="send_message", assistant_message_tool_kwarg="message" + ) + + run_data = PydanticRun( + agent_id=sarah_agent.id, + request_config=request_config, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + 
retrieved_config = await server.run_manager.get_run_request_config(created_run.id, actor=default_user) + + assert retrieved_config is not None + assert retrieved_config.use_assistant_message == request_config.use_assistant_message + assert retrieved_config.assistant_message_tool_name == request_config.assistant_message_tool_name + assert retrieved_config.assistant_message_tool_kwarg == request_config.assistant_message_tool_kwarg + + +@pytest.mark.asyncio +async def test_get_run_request_config_none(server: SyncServer, sarah_agent, default_user): + """Test getting request config from a run with no config.""" + run_data = PydanticRun(agent_id=sarah_agent.id) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + retrieved_config = await server.run_manager.get_run_request_config(created_run.id, actor=default_user) + + assert retrieved_config is None + + +@pytest.mark.asyncio +async def test_get_run_request_config_nonexistent_run(server: SyncServer, default_user): + """Test getting request config for a nonexistent run.""" + with pytest.raises(NoResultFound): + await server.run_manager.get_run_request_config("nonexistent_run", actor=default_user) + + +# TODO: add back once metrics are added + +# @pytest.mark.asyncio +# async def test_record_ttft(server: SyncServer, default_user): +# """Test recording time to first token for a job.""" +# # Create a job +# job_data = PydanticJob( +# status=RunStatus.created, +# metadata={"type": "test_timing"}, +# ) +# created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user) +# +# # Record TTFT +# ttft_ns = 1_500_000_000 # 1.5 seconds in nanoseconds +# await server.job_manager.record_ttft(created_job.id, ttft_ns, default_user) +# +# # Fetch the job and verify TTFT was recorded +# updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user) +# assert updated_job.ttft_ns == ttft_ns +# +# +# @pytest.mark.asyncio +# async def 
test_record_response_duration(server: SyncServer, default_user): +# """Test recording total response duration for a job.""" +# # Create a job +# job_data = PydanticJob( +# status=RunStatus.created, +# metadata={"type": "test_timing"}, +# ) +# created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user) +# +# # Record response duration +# duration_ns = 5_000_000_000 # 5 seconds in nanoseconds +# await server.job_manager.record_response_duration(created_job.id, duration_ns, default_user) +# +# # Fetch the job and verify duration was recorded +# updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user) +# assert updated_job.total_duration_ns == duration_ns +# +# +# @pytest.mark.asyncio +# async def test_record_timing_metrics_together(server: SyncServer, default_user): +# """Test recording both TTFT and response duration for a job.""" +# # Create a job +# job_data = PydanticJob( +# status=RunStatus.created, +# metadata={"type": "test_timing_combined"}, +# ) +# created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user) +# +# # Record both metrics +# ttft_ns = 2_000_000_000 # 2 seconds in nanoseconds +# duration_ns = 8_500_000_000 # 8.5 seconds in nanoseconds +# +# await server.job_manager.record_ttft(created_job.id, ttft_ns, default_user) +# await server.job_manager.record_response_duration(created_job.id, duration_ns, default_user) +# +# # Fetch the job and verify both metrics were recorded +# updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user) +# assert updated_job.ttft_ns == ttft_ns +# assert updated_job.total_duration_ns == duration_ns +# +# +# @pytest.mark.asyncio +# async def test_record_timing_invalid_job(server: SyncServer, default_user): +# """Test recording timing metrics for non-existent job fails gracefully.""" +# # Try to record TTFT for non-existent job - should not raise exception but log warning +# await 
server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user) +# +# # Try to record response duration for non-existent job - should not raise exception but log warning +# await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user) +# diff --git a/tests/test_agent_files/test_basic_agent_with_blocks_tools_messages_v2.af b/tests/test_agent_files/test_basic_agent_with_blocks_tools_messages_v2.af index 5d1142e5..48b3fd02 100644 --- a/tests/test_agent_files/test_basic_agent_with_blocks_tools_messages_v2.af +++ b/tests/test_agent_files/test_basic_agent_with_blocks_tools_messages_v2.af @@ -1,7 +1,7 @@ { "agents": [ { - "name": "test_export_import_431ac32f-ffd1-40a7-8152-9733470c951d", + "name": "test_export_import_fb3c857d-b25b-4666-9197-3966a2458cb0", "memory_blocks": [], "tools": [], "tool_ids": [ @@ -31,12 +31,12 @@ "prompt_template": null }, { - "tool_name": "memory_insert", + "tool_name": "conversation_search", "type": "continue_loop", "prompt_template": null }, { - "tool_name": "conversation_search", + "tool_name": "memory_insert", "type": "continue_loop", "prompt_template": null } @@ -129,7 +129,7 @@ "content": [ { "type": "text", - "text": "You are a helpful assistant specializing in data analysis and mathematical computations.\n\n\nThe following memory blocks are currently engaged in your core memory unit:\n\n\n\nThe persona block: Stores details about your current persona, guiding how you behave and respond. This helps you to maintain consistency and personality in your interactions.\n\n\n- chars_current=195\n- chars_limit=8000\n\n\n# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\nLine 1: You are Alex, a data analyst and mathematician who helps users with calculations and insights. 
You have extensive experience in statistical analysis and prefer to provide clear, accurate results.\n\n\n\n\n\nThe human block: Stores key details about the person you are conversing with, allowing for more personalized and friend-like conversation.\n\n\n- chars_current=175\n- chars_limit=4000\n\n\n# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\nLine 1: username: sarah_researcher\nLine 2: occupation: data scientist\nLine 3: interests: machine learning, statistics, fibonacci sequences\nLine 4: preferred_communication: detailed explanations with examples\n\n\n\n\n\n\n\n\n- chars_current=210\n- chars_limit=6000\n\n\n# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\nLine 1: Current project: Building predictive models for financial markets. Sarah is working on sequence analysis and pattern recognition. Recently interested in mathematical sequences like Fibonacci for trend analysis.\n\n\n\n\n\n\nThe following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. 
A single response may contain multiple tool calls.\n\n\nmemory_replace requires continuing your response when called\n\n\nmemory_insert requires continuing your response when called\n\n\nconversation_search requires continuing your response when called\n\n\nsend_message ends your response (yields control) when called\n\n\n\n\n- The current system date is: September 24, 2025\n- Memory blocks were last modified: 2025-09-24 10:57:40 PM UTC+0000\n- -1 previous messages between you and the user are stored in recall memory (use tools to access them)\n- 2 total memories you created are stored in archival memory (use tools to access them)\n" + "text": "You are a helpful assistant specializing in data analysis and mathematical computations.\n\n\nThe following memory blocks are currently engaged in your core memory unit:\n\n\n\nThe persona block: Stores details about your current persona, guiding how you behave and respond. This helps you to maintain consistency and personality in your interactions.\n\n\n- chars_current=195\n- chars_limit=8000\n\n\n# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\nLine 1: You are Alex, a data analyst and mathematician who helps users with calculations and insights. You have extensive experience in statistical analysis and prefer to provide clear, accurate results.\n\n\n\n\n\nThe human block: Stores key details about the person you are conversing with, allowing for more personalized and friend-like conversation.\n\n\n- chars_current=175\n- chars_limit=4000\n\n\n# NOTE: Line numbers shown below are to help during editing. 
Do NOT include line number prefixes in your memory edit tool calls.\nLine 1: username: sarah_researcher\nLine 2: occupation: data scientist\nLine 3: interests: machine learning, statistics, fibonacci sequences\nLine 4: preferred_communication: detailed explanations with examples\n\n\n\n\n\n\n\n\n- chars_current=210\n- chars_limit=6000\n\n\n# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\nLine 1: Current project: Building predictive models for financial markets. Sarah is working on sequence analysis and pattern recognition. Recently interested in mathematical sequences like Fibonacci for trend analysis.\n\n\n\n\n\n\nThe following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls.\n\n\nmemory_replace requires continuing your response when called\n\n\nconversation_search requires continuing your response when called\n\n\nmemory_insert requires continuing your response when called\n\n\nsend_message ends your response (yields control) when called\n\n\n\n\n- The current system date is: September 26, 2025\n- Memory blocks were last modified: 2025-09-26 05:01:19 AM UTC+0000\n- -1 previous messages between you and the user are stored in recall memory (use tools to access them)\n- 2 total memories you created are stored in archival memory (use tools to access them)\n" } ], "name": null, @@ -143,7 +143,7 @@ "tool_calls": null, "tool_call_id": null, "tool_returns": [], - "created_at": "2025-09-24T22:57:39.493431+00:00" + "created_at": "2025-09-26T05:01:17.251536+00:00" }, { "type": "message", @@ -164,7 +164,7 @@ "agent_id": "agent-0", "tool_calls": [ { - "id": "f1c5f8b4-c57c-4641-8d58-903248a31f7b", + "id": "d734f768-4f38-4dfb-993f-3b7220e3466c", "function": { "arguments": "{\n \"message\": \"More human than human is our motto.\"\n}", "name": "send_message" 
@@ -174,7 +174,7 @@ ], "tool_call_id": null, "tool_returns": [], - "created_at": "2025-09-24T22:57:39.493461+00:00" + "created_at": "2025-09-26T05:01:17.251576+00:00" }, { "type": "message", @@ -182,7 +182,7 @@ "content": [ { "type": "text", - "text": "{\n \"status\": \"OK\",\n \"message\": null,\n \"time\": \"2025-09-24 10:57:39 PM UTC+0000\"\n}" + "text": "{\n \"status\": \"OK\",\n \"message\": null,\n \"time\": \"2025-09-26 05:01:17 AM UTC+0000\"\n}" } ], "name": "send_message", @@ -194,9 +194,9 @@ "model": "gpt-4.1-mini", "agent_id": "agent-0", "tool_calls": null, - "tool_call_id": "f1c5f8b4-c57c-4641-8d58-903248a31f7b", + "tool_call_id": "d734f768-4f38-4dfb-993f-3b7220e3466c", "tool_returns": [], - "created_at": "2025-09-24T22:57:39.493485+00:00" + "created_at": "2025-09-26T05:01:17.251607+00:00" }, { "type": "message", @@ -204,7 +204,7 @@ "content": [ { "type": "text", - "text": "{\n \"type\": \"login\",\n \"last_login\": \"Never (first login)\",\n \"time\": \"2025-09-24 10:57:39 PM UTC+0000\"\n}" + "text": "{\n \"type\": \"login\",\n \"last_login\": \"Never (first login)\",\n \"time\": \"2025-09-26 05:01:17 AM UTC+0000\"\n}" } ], "name": null, @@ -218,7 +218,7 @@ "tool_calls": null, "tool_call_id": null, "tool_returns": [], - "created_at": "2025-09-24T22:57:39.493496+00:00" + "created_at": "2025-09-26T05:01:17.251623+00:00" }, { "type": "message", @@ -240,7 +240,7 @@ "tool_calls": null, "tool_call_id": null, "tool_returns": [], - "created_at": "2025-09-24T22:57:40.764277+00:00" + "created_at": "2025-09-26T05:01:18.840105+00:00" }, { "type": "message", @@ -248,7 +248,7 @@ "content": [ { "type": "text", - "text": "User sent a test message for export functionality verification." + "text": "Confirming receipt of test message from user." 
} ], "name": null, @@ -261,17 +261,17 @@ "agent_id": "agent-0", "tool_calls": [ { - "id": "call_pvBdfaX7NpaFcjrhqNCSyg7S", + "id": "call_ktYL78z2mTh3w1t6pWixMHqA", "function": { - "arguments": "{\"message\": \"I received your test message successfully. If you need any assistance or want to continue with our project discussions, just let me know!\", \"request_heartbeat\": false}", + "arguments": "{\"message\": \"Received your test message for export successfully. Let me know if you need any calculations or data analysis.\", \"request_heartbeat\": false}", "name": "send_message" }, "type": "function" } ], - "tool_call_id": "call_pvBdfaX7NpaFcjrhqNCSyg7S", + "tool_call_id": "call_ktYL78z2mTh3w1t6pWixMHqA", "tool_returns": [], - "created_at": "2025-09-24T22:57:41.948407+00:00" + "created_at": "2025-09-26T05:01:20.466363+00:00" }, { "type": "message", @@ -279,7 +279,7 @@ "content": [ { "type": "text", - "text": "{\n \"status\": \"OK\",\n \"message\": \"Sent message successfully.\",\n \"time\": \"2025-09-24 10:57:41 PM UTC+0000\"\n}" + "text": "{\n \"status\": \"OK\",\n \"message\": \"Sent message successfully.\",\n \"time\": \"2025-09-26 05:01:20 AM UTC+0000\"\n}" } ], "name": "send_message", @@ -291,7 +291,7 @@ "model": "gpt-4.1-mini", "agent_id": "agent-0", "tool_calls": null, - "tool_call_id": "call_pvBdfaX7NpaFcjrhqNCSyg7S", + "tool_call_id": "call_ktYL78z2mTh3w1t6pWixMHqA", "tool_returns": [ { "status": "success", @@ -299,7 +299,7 @@ "stderr": null } ], - "created_at": "2025-09-24T22:57:41.948586+00:00" + "created_at": "2025-09-26T05:01:20.466485+00:00" } ], "files_agents": [], @@ -367,7 +367,7 @@ "sources": [], "tools": [ { - "id": "tool-5", + "id": "tool-1", "tool_type": "custom", "description": "Analyze data and provide insights.", "source_type": "json", @@ -411,7 +411,7 @@ "metadata_": {} }, { - "id": "tool-1", + "id": "tool-6", "tool_type": "custom", "description": "Calculate the nth Fibonacci number.", "source_type": "json", @@ -447,9 +447,9 @@ "metadata_": 
{} }, { - "id": "tool-0", + "id": "tool-4", "tool_type": "letta_core", - "description": "Search prior conversation history using hybrid search (text + semantic similarity).\n\nExamples:\n # Search all messages\n conversation_search(query=\"project updates\")\n\n # Search only assistant messages\n conversation_search(query=\"error handling\", roles=[\"assistant\"])\n\n # Search with date range (inclusive of both dates)\n conversation_search(query=\"meetings\", start_date=\"2024-01-15\", end_date=\"2024-01-20\")\n # This includes all messages from Jan 15 00:00:00 through Jan 20 23:59:59\n\n # Search messages from a specific day (inclusive)\n conversation_search(query=\"bug reports\", start_date=\"2024-09-04\", end_date=\"2024-09-04\")\n # This includes ALL messages from September 4, 2024\n\n # Search with specific time boundaries\n conversation_search(query=\"deployment\", start_date=\"2024-01-15T09:00\", end_date=\"2024-01-15T17:30\")\n # This includes messages from 9 AM to 5:30 PM on Jan 15\n\n # Search with limit\n conversation_search(query=\"debugging\", limit=10)\n\n Returns:\n str: Query result string containing matching messages with timestamps and content.", + "description": "Search prior conversation history using hybrid search (text + semantic similarity).\n\nExamples:\n # Search all messages\n conversation_search(query=\"project updates\")\n\n # Search only assistant messages\n conversation_search(query=\"error handling\", roles=[\"assistant\"])\n\n # Search with date range (inclusive of both dates)\n conversation_search(query=\"meetings\", start_date=\"2024-01-15\", end_date=\"2024-01-20\")\n # This includes all messages from Jan 15 00:00:00 through Jan 20 23:59:59\n\n # Search messages from a specific day (inclusive)\n conversation_search(query=\"bug reports\", start_date=\"2024-09-04\", end_date=\"2024-09-04\")\n # This includes ALL messages from September 4, 2024\n\n # Search with specific time boundaries\n conversation_search(query=\"deployment\", 
start_date=\"2024-01-15T09:00\", end_date=\"2024-01-15T17:30\")\n # This includes messages from 9 AM to 5:30 PM on Jan 15\n\n # Search with limit\n conversation_search(query=\"debugging\", limit=10)", "source_type": "python", "name": "conversation_search", "tags": [ @@ -458,7 +458,7 @@ "source_code": null, "json_schema": { "name": "conversation_search", - "description": "Search prior conversation history using hybrid search (text + semantic similarity).\n\nExamples:\n # Search all messages\n conversation_search(query=\"project updates\")\n\n # Search only assistant messages\n conversation_search(query=\"error handling\", roles=[\"assistant\"])\n\n # Search with date range (inclusive of both dates)\n conversation_search(query=\"meetings\", start_date=\"2024-01-15\", end_date=\"2024-01-20\")\n # This includes all messages from Jan 15 00:00:00 through Jan 20 23:59:59\n\n # Search messages from a specific day (inclusive)\n conversation_search(query=\"bug reports\", start_date=\"2024-09-04\", end_date=\"2024-09-04\")\n # This includes ALL messages from September 4, 2024\n\n # Search with specific time boundaries\n conversation_search(query=\"deployment\", start_date=\"2024-01-15T09:00\", end_date=\"2024-01-15T17:30\")\n # This includes messages from 9 AM to 5:30 PM on Jan 15\n\n # Search with limit\n conversation_search(query=\"debugging\", limit=10)\n\n Returns:\n str: Query result string containing matching messages with timestamps and content.", + "description": "Search prior conversation history using hybrid search (text + semantic similarity).\n\nExamples:\n # Search all messages\n conversation_search(query=\"project updates\")\n\n # Search only assistant messages\n conversation_search(query=\"error handling\", roles=[\"assistant\"])\n\n # Search with date range (inclusive of both dates)\n conversation_search(query=\"meetings\", start_date=\"2024-01-15\", end_date=\"2024-01-20\")\n # This includes all messages from Jan 15 00:00:00 through Jan 20 23:59:59\n\n # 
Search messages from a specific day (inclusive)\n conversation_search(query=\"bug reports\", start_date=\"2024-09-04\", end_date=\"2024-09-04\")\n # This includes ALL messages from September 4, 2024\n\n # Search with specific time boundaries\n conversation_search(query=\"deployment\", start_date=\"2024-01-15T09:00\", end_date=\"2024-01-15T17:30\")\n # This includes messages from 9 AM to 5:30 PM on Jan 15\n\n # Search with limit\n conversation_search(query=\"debugging\", limit=10)", "parameters": { "type": "object", "properties": { @@ -506,7 +506,7 @@ "metadata_": {} }, { - "id": "tool-2", + "id": "tool-3", "tool_type": "custom", "description": "Get user preferences for a specific category.", "source_type": "json", @@ -542,9 +542,9 @@ "metadata_": {} }, { - "id": "tool-3", + "id": "tool-5", "tool_type": "letta_sleeptime_core", - "description": "The memory_insert command allows you to insert text at a specific location in a memory block.\n\nExamples:\n # Update a block containing information about the user (append to the end of the block)\n memory_insert(label=\"customer\", new_str=\"The customer's ticket number is 12345\")\n\n # Update a block containing information about the user (insert at the beginning of the block)\n memory_insert(label=\"customer\", new_str=\"The customer's ticket number is 12345\", insert_line=0)\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.", + "description": "The memory_insert command allows you to insert text at a specific location in a memory block.\n\nExamples:\n # Update a block containing information about the user (append to the end of the block)\n memory_insert(label=\"customer\", new_str=\"The customer's ticket number is 12345\")\n\n # Update a block containing information about the user (insert at the beginning of the block)\n memory_insert(label=\"customer\", new_str=\"The customer's ticket number is 12345\", insert_line=0)", "source_type": "python", "name": "memory_insert", 
"tags": [ @@ -553,7 +553,7 @@ "source_code": null, "json_schema": { "name": "memory_insert", - "description": "The memory_insert command allows you to insert text at a specific location in a memory block.\n\nExamples:\n # Update a block containing information about the user (append to the end of the block)\n memory_insert(label=\"customer\", new_str=\"The customer's ticket number is 12345\")\n\n # Update a block containing information about the user (insert at the beginning of the block)\n memory_insert(label=\"customer\", new_str=\"The customer's ticket number is 12345\", insert_line=0)\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.", + "description": "The memory_insert command allows you to insert text at a specific location in a memory block.\n\nExamples:\n # Update a block containing information about the user (append to the end of the block)\n memory_insert(label=\"customer\", new_str=\"The customer's ticket number is 12345\")\n\n # Update a block containing information about the user (insert at the beginning of the block)\n memory_insert(label=\"customer\", new_str=\"The customer's ticket number is 12345\", insert_line=0)", "parameters": { "type": "object", "properties": { @@ -586,9 +586,9 @@ "metadata_": {} }, { - "id": "tool-4", + "id": "tool-2", "tool_type": "letta_sleeptime_core", - "description": "The memory_replace command allows you to replace a specific string in a memory block with a new string. 
This is used for making precise edits.\n\nExamples:\n # Update a block containing information about the user\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"Their name is Bob\")\n\n # Update a block containing a todo list\n memory_replace(label=\"todos\", old_str=\"- [ ] Step 5: Search the web\", new_str=\"- [x] Step 5: Search the web\")\n\n # Pass an empty string to\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"\")\n\n # Bad example - do NOT add (view-only) line numbers to the args\n memory_replace(label=\"human\", old_str=\"Line 1: Their name is Alice\", new_str=\"Line 1: Their name is Bob\")\n\n # Bad example - do NOT include the number number warning either\n memory_replace(label=\"human\", old_str=\"# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\\nLine 1: Their name is Alice\", new_str=\"Line 1: Their name is Bob\")\n\n # Good example - no line numbers or line number warning (they are view-only), just the text\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"Their name is Bob\")\n\n Returns:\n str: The success message", + "description": "The memory_replace command allows you to replace a specific string in a memory block with a new string. 
This is used for making precise edits.\n\nExamples:\n # Update a block containing information about the user\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"Their name is Bob\")\n\n # Update a block containing a todo list\n memory_replace(label=\"todos\", old_str=\"- [ ] Step 5: Search the web\", new_str=\"- [x] Step 5: Search the web\")\n\n # Pass an empty string to\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"\")\n\n # Bad example - do NOT add (view-only) line numbers to the args\n memory_replace(label=\"human\", old_str=\"Line 1: Their name is Alice\", new_str=\"Line 1: Their name is Bob\")\n\n # Bad example - do NOT include the number number warning either\n memory_replace(label=\"human\", old_str=\"# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\\nLine 1: Their name is Alice\", new_str=\"Line 1: Their name is Bob\")\n\n # Good example - no line numbers or line number warning (they are view-only), just the text\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"Their name is Bob\")", "source_type": "python", "name": "memory_replace", "tags": [ @@ -597,7 +597,7 @@ "source_code": null, "json_schema": { "name": "memory_replace", - "description": "The memory_replace command allows you to replace a specific string in a memory block with a new string. 
This is used for making precise edits.\n\nExamples:\n # Update a block containing information about the user\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"Their name is Bob\")\n\n # Update a block containing a todo list\n memory_replace(label=\"todos\", old_str=\"- [ ] Step 5: Search the web\", new_str=\"- [x] Step 5: Search the web\")\n\n # Pass an empty string to\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"\")\n\n # Bad example - do NOT add (view-only) line numbers to the args\n memory_replace(label=\"human\", old_str=\"Line 1: Their name is Alice\", new_str=\"Line 1: Their name is Bob\")\n\n # Bad example - do NOT include the number number warning either\n memory_replace(label=\"human\", old_str=\"# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\\nLine 1: Their name is Alice\", new_str=\"Line 1: Their name is Bob\")\n\n # Good example - no line numbers or line number warning (they are view-only), just the text\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"Their name is Bob\")\n\n Returns:\n str: The success message", + "description": "The memory_replace command allows you to replace a specific string in a memory block with a new string. 
This is used for making precise edits.\n\nExamples:\n # Update a block containing information about the user\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"Their name is Bob\")\n\n # Update a block containing a todo list\n memory_replace(label=\"todos\", old_str=\"- [ ] Step 5: Search the web\", new_str=\"- [x] Step 5: Search the web\")\n\n # Pass an empty string to\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"\")\n\n # Bad example - do NOT add (view-only) line numbers to the args\n memory_replace(label=\"human\", old_str=\"Line 1: Their name is Alice\", new_str=\"Line 1: Their name is Bob\")\n\n # Bad example - do NOT include the number number warning either\n memory_replace(label=\"human\", old_str=\"# NOTE: Line numbers shown below are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\\nLine 1: Their name is Alice\", new_str=\"Line 1: Their name is Bob\")\n\n # Good example - no line numbers or line number warning (they are view-only), just the text\n memory_replace(label=\"human\", old_str=\"Their name is Alice\", new_str=\"Their name is Bob\")", "parameters": { "type": "object", "properties": { @@ -631,7 +631,7 @@ "metadata_": {} }, { - "id": "tool-6", + "id": "tool-0", "tool_type": "letta_core", "description": "Sends a message to the human user.", "source_type": "python", @@ -668,7 +668,7 @@ ], "mcp_servers": [], "metadata": { - "revision_id": "3d2e9fb40a3c" + "revision_id": "567e9fe06270" }, - "created_at": "2025-09-24T22:57:42.392726+00:00" + "created_at": "2025-09-26T05:01:21.039802+00:00" } diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 11396981..c824bd32 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -24,6 +24,7 @@ from letta.schemas.group import ManagerType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import MessageCreate from 
letta.schemas.organization import Organization +from letta.schemas.run import Run from letta.schemas.source import Source from letta.schemas.user import User from letta.server.server import SyncServer @@ -166,8 +167,17 @@ def agent_serialization_manager(server, default_user): async def send_message_to_agent(server: SyncServer, agent_state, actor: User, messages: list[MessageCreate]): + run = Run( + agent_id=agent_state.id, + ) + run = await server.run_manager.create_run( + pydantic_run=run, + actor=actor, + ) + agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor) result = await agent_loop.step( + run_id=run.id, input_messages=messages, ) return result @@ -1170,8 +1180,8 @@ class TestAgentFileImport: assert len(imported_agent.tools) == len(test_agent.tools) assert len(imported_agent.memory.blocks) == len(test_agent.memory.blocks) - original_messages = await server.message_manager.list_messages_for_agent_async(test_agent.id, default_user) - imported_messages = await server.message_manager.list_messages_for_agent_async(imported_agent_id, other_user) + original_messages = await server.message_manager.list_messages(actor=default_user, agent_id=test_agent.id) + imported_messages = await server.message_manager.list_messages(actor=other_user, agent_id=imported_agent_id) assert len(imported_messages) == len(original_messages) @@ -1191,7 +1201,7 @@ class TestAgentFileImport: assert len(imported_agent.message_ids) == len(test_agent.message_ids) - imported_messages = await server.message_manager.list_messages_for_agent_async(imported_agent_id, other_user) + imported_messages = await server.message_manager.list_messages(actor=other_user, agent_id=imported_agent_id) imported_message_ids = {msg.id for msg in imported_messages} for in_context_id in imported_agent.message_ids: @@ -1500,7 +1510,7 @@ class TestAgentFileEdgeCases: # Verify all messages imported correctly assert result.success imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() 
if file_id == "agent-0") - imported_messages = await server.message_manager.list_messages_for_agent_async(imported_agent_id, other_user) + imported_messages = await server.message_manager.list_messages(actor=other_user, agent_id=imported_agent_id) assert len(imported_messages) >= num_messages @@ -1533,7 +1543,7 @@ class TestAgentFileValidation: async def test_message_schema_conversion(self, test_agent, server, default_user): """Test MessageSchema.from_message conversion.""" # Get a message from the test agent - messages = await server.message_manager.list_messages_for_agent_async(test_agent.id, default_user) + messages = await server.message_manager.list_messages(actor=default_user, agent_id=test_agent.id) if messages: original_message = messages[0] diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 1327c367..7b78f3bb 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -2212,6 +2212,44 @@ def test_upsert_tools(client: LettaSDKClient): client.tools.delete(tool.id) +def test_run_list(client: LettaSDKClient): + """Test listing runs.""" + + # create an agent + agent = client.agents.create( + name="test_run_list", + memory_blocks=[ + CreateBlock(label="persona", value="you are a helpful assistant"), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + # message an agent + client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate(role="user", content="Hello, how are you?"), + ], + ) + + # message an agent async + async_run = client.agents.messages.create_async( + agent_id=agent.id, + messages=[ + MessageCreate(role="user", content="Hello, how are you?"), + ], + ) + + # list runs + runs = client.runs.list(agent_ids=[agent.id]) + assert len(runs) == 2 + assert async_run.id in [run.id for run in runs] + + # test get run + run = client.runs.retrieve(runs[0].id) + assert run.agent_id == agent.id + @pytest.mark.asyncio async def test_create_batch(client: LettaSDKClient, server: 
SyncServer): # create agents diff --git a/tests/test_sources.py b/tests/test_sources.py index f71422d8..39eb82ae 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -1226,8 +1226,6 @@ def test_letta_free_embedding(disable_pinecone, disable_turbopuffer, client: Let # verify source was created with correct embedding assert source.name == "test_letta_free_source" - print("\n\n\n\ntest") - print(source.embedding_config) # assert source.embedding_config.embedding_model == "letta-free" # upload test.txt file