diff --git a/.github/workflows/core-unit-sqlite-test.yaml b/.github/workflows/core-unit-sqlite-test.yaml index fe1a0d2f..87ff9517 100644 --- a/.github/workflows/core-unit-sqlite-test.yaml +++ b/.github/workflows/core-unit-sqlite-test.yaml @@ -28,7 +28,6 @@ jobs: install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra sqlite' timeout-minutes: 15 ref: ${{ github.event.pull_request.head.sha || github.sha }} - matrix-strategy: | { "fail-fast": false, diff --git a/.github/workflows/docker-integration-tests.yaml b/.github/workflows/docker-integration-tests.yaml index b58311ef..9713cd0a 100644 --- a/.github/workflows/docker-integration-tests.yaml +++ b/.github/workflows/docker-integration-tests.yaml @@ -10,6 +10,29 @@ jobs: test: runs-on: ubuntu-latest timeout-minutes: 15 + env: + # Database configuration - these will be used by dev-compose.yaml + LETTA_PG_DB: letta + LETTA_PG_USER: letta + LETTA_PG_PASSWORD: letta + LETTA_PG_HOST: pgvector_db # Internal Docker service name + LETTA_PG_PORT: 5432 + # Server configuration for tests + LETTA_SERVER_PASS: test_server_token + LETTA_SERVER_URL: http://localhost:8283 + # API keys + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + # Additional API keys that dev-compose.yaml expects (optional) + GROQ_API_KEY: "" + ANTHROPIC_API_KEY: "" + OLLAMA_BASE_URL: "" + AZURE_API_KEY: "" + AZURE_BASE_URL: "" + AZURE_API_VERSION: "" + GEMINI_API_KEY: "" + VLLM_API_BASE: "" + OPENLLM_AUTH_TYPE: "" + OPENLLM_API_KEY: "" steps: - name: Checkout uses: actions/checkout@v4 @@ -32,14 +55,8 @@ jobs: chmod -R 755 /home/runner/.letta/logs - name: Build and run docker dev server - env: - LETTA_PG_DB: letta - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_PORT: 8888 - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run: | + # dev-compose.yaml will use the environment variables we set above docker compose -f dev-compose.yaml up --build -d - name: Wait for service @@ -47,13 +64,6 @@ jobs: - name: Run tests with pytest env: - LETTA_PG_DB: letta - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_PORT: 8888 - LETTA_SERVER_PASS: test_server_token - LETTA_SERVER_URL: http://localhost:8283 - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} run: | uv sync --extra dev --extra postgres --extra sqlite diff --git a/alembic/env.py b/alembic/env.py index dac40ea4..4202cdfa 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -15,8 +15,13 @@ letta_config = LettaConfig.load() config = context.config if settings.database_engine is DatabaseChoice.POSTGRES: - config.set_main_option("sqlalchemy.url", settings.letta_pg_uri) - print("Using database: ", settings.letta_pg_uri) + # Convert PostgreSQL URI to sync format for alembic using common utility + from letta.database_utils import get_database_uri_for_context + + sync_pg_uri = get_database_uri_for_context(settings.letta_pg_uri, "alembic") + + config.set_main_option("sqlalchemy.url", sync_pg_uri) + print("Using database: ", sync_pg_uri) else: config.set_main_option("sqlalchemy.url", "sqlite:///" + os.path.join(letta_config.recall_storage_path, "sqlite.db")) diff --git a/alembic/versions/c734cfc0d595_add_runs_metrics_table.py b/alembic/versions/c734cfc0d595_add_runs_metrics_table.py new file mode 100644 index 00000000..6f0db489 --- /dev/null +++ b/alembic/versions/c734cfc0d595_add_runs_metrics_table.py @@ -0,0 +1,55 @@ +"""add runs_metrics table + +Revision ID: c734cfc0d595 +Revises: 038e68cdf0df +Create Date: 2025-10-08 14:35:23.302204 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c734cfc0d595" +down_revision: Union[str, None] = "038e68cdf0df" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "run_metrics", + sa.Column("id", sa.String(), nullable=False), + sa.Column("run_start_ns", sa.BigInteger(), nullable=True), + sa.Column("run_ns", sa.BigInteger(), nullable=True), + sa.Column("num_steps", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("project_id", sa.String(), nullable=True), + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("base_template_id", sa.String(), nullable=True), + sa.Column("template_id", sa.String(), nullable=True), + sa.Column("deployment_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["id"], ["runs.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("run_metrics") + # ### end Alembic commands ### diff --git a/dev-compose.yaml b/dev-compose.yaml index 36fd5c54..81d08478 100644 --- a/dev-compose.yaml +++ b/dev-compose.yaml @@ -31,8 +31,9 @@ services: - LETTA_PG_DB=${LETTA_PG_DB:-letta} - LETTA_PG_USER=${LETTA_PG_USER:-letta} - LETTA_PG_PASSWORD=${LETTA_PG_PASSWORD:-letta} - - LETTA_PG_HOST=pgvector_db - - LETTA_PG_PORT=5432 + - LETTA_PG_HOST=${LETTA_PG_HOST:-pgvector_db} + - LETTA_PG_PORT=${LETTA_PG_PORT:-5432} + - LETTA_PG_URI=${LETTA_PG_URI:-postgresql://${LETTA_PG_USER:-letta}:${LETTA_PG_PASSWORD:-letta}@${LETTA_PG_HOST:-pgvector_db}:${LETTA_PG_PORT:-5432}/${LETTA_PG_DB:-letta}} - LETTA_DEBUG=True - OPENAI_API_KEY=${OPENAI_API_KEY} - GROQ_API_KEY=${GROQ_API_KEY} diff --git a/fern/openapi.json b/fern/openapi.json index d04395fb..9ef69f2c 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -10698,6 +10698,47 @@ } } }, + "/v1/runs/{run_id}/metrics": { + "get": { + "tags": ["runs"], + "summary": "Retrieve Metrics For Run", + "description": "Get run metrics by run ID.", + "operationId": "retrieve_metrics_for_run", + "parameters": [ + { + "name": "run_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Run Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunMetrics" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/runs/{run_id}/steps": { "get": { "tags": ["runs"], @@ -31576,6 +31617,103 @@ "title": "Run", "description": "Representation of a run - a conversation or processing session for an agent.\nRuns track when agents process messages and maintain the relationship between agents, steps, and messages.\n\nParameters:\n id (str): The unique identifier of the run (prefixed with 'run-').\n status (JobStatus): The current status of the run.\n created_at (datetime): The timestamp when the run was created.\n completed_at (datetime): The timestamp when the run was completed.\n agent_id (str): The unique identifier of the agent associated with the run.\n stop_reason (StopReasonType): The reason why the run was stopped.\n background (bool): Whether the run was created in background mode.\n metadata (dict): Additional metadata for the run.\n request_config (LettaRequestConfig): The request configuration for the run." }, + "RunMetrics": { + "properties": { + "id": { + "type": "string", + "title": "Id", + "description": "The id of the run this metric belongs to (matches runs.id)." + }, + "agent_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Agent Id", + "description": "The unique identifier of the agent." + }, + "project_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Project Id", + "description": "The project that the run belongs to (cloud only)." + }, + "run_start_ns": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Run Start Ns", + "description": "The timestamp of the start of the run in nanoseconds." + }, + "run_ns": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Run Ns", + "description": "Total time for the run in nanoseconds." + }, + "num_steps": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Num Steps", + "description": "The number of steps in the run." + }, + "template_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Template Id", + "description": "The template ID that the run belongs to (cloud only)." + }, + "base_template_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Template Id", + "description": "The base template ID that the run belongs to (cloud only)." + } + }, + "additionalProperties": false, + "type": "object", + "required": ["id"], + "title": "RunMetrics" + }, "RunStatus": { "type": "string", "enum": ["created", "running", "completed", "failed", "cancelled"], diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index edda758d..1df13a1c 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -595,9 +595,30 @@ class LettaAgentV3(LettaAgentV2): # -1. no tool call, no content if tool_call is None and (content is None or len(content) == 0): # Edge case is when there's also no content - basically, the LLM "no-op'd" - # In this case, we actually do not want to persist the no-op message - continue_stepping, heartbeat_reason, stop_reason = False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value) - messages_to_persist = initial_messages or [] + # If RequiredBeforeExitToolRule exists and not all required tools have been called, + # inject a rule-violation heartbeat to keep looping and inform the model. + uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools])) + if uncalled: + # TODO: we may need to change this to not have a "heartbeat" prefix for v3? + heartbeat_reason = ( + f"{NON_USER_MSG_PREFIX}ToolRuleViolated: You must call {', '.join(uncalled)} at least once to exit the loop." + ) + from letta.server.rest_api.utils import create_heartbeat_system_message + + heartbeat_msg = create_heartbeat_system_message( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + function_call_success=True, + timezone=agent_state.timezone, + heartbeat_reason=heartbeat_reason, + run_id=run_id, + ) + messages_to_persist = (initial_messages or []) + [heartbeat_msg] + continue_stepping, stop_reason = True, None + else: + # In this case, we actually do not want to persist the no-op message + continue_stepping, heartbeat_reason, stop_reason = False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value) + messages_to_persist = initial_messages or [] # 0. If there's no tool call, we can early exit elif tool_call is None: @@ -627,7 +648,8 @@ class LettaAgentV3(LettaAgentV2): run_id=run_id, is_approval_response=is_approval or is_denial, force_set_request_heartbeat=False, - add_heartbeat_on_continue=False, + # If we're continuing due to a required-before-exit rule, include a heartbeat to guide the model + add_heartbeat_on_continue=bool(heartbeat_reason), ) messages_to_persist = (initial_messages or []) + assistant_message @@ -843,7 +865,13 @@ class LettaAgentV3(LettaAgentV2): stop_reason: LettaStopReason | None = None if tool_call_name is None: - # No tool call? End loop + # No tool call – if there are required-before-exit tools uncalled, keep stepping + # and provide explicit feedback to the model; otherwise end the loop. + uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools])) + if uncalled and not is_final_step: + reason = f"{NON_USER_MSG_PREFIX}ToolRuleViolated: You must call {', '.join(uncalled)} at least once to exit the loop." + return True, reason, None + # No required tools remaining → end turn return False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value) else: if tool_rule_violated: diff --git a/letta/database_utils.py b/letta/database_utils.py new file mode 100644 index 00000000..f5c499cb --- /dev/null +++ b/letta/database_utils.py @@ -0,0 +1,161 @@ +""" +Database URI utilities for consistent database connection handling across the application. + +This module provides utilities for parsing and converting database URIs to ensure +consistent behavior between the main application, alembic migrations, and other +database-related components. +""" + +from typing import Optional +from urllib.parse import urlparse, urlunparse + + +def parse_database_uri(uri: str) -> dict[str, Optional[str]]: + """ + Parse a database URI into its components. + + Args: + uri: Database URI (e.g., postgresql://user:pass@host:port/db) + + Returns: + Dictionary with parsed components: scheme, driver, user, password, host, port, database + """ + parsed = urlparse(uri) + + # Extract driver from scheme (e.g., postgresql+asyncpg -> asyncpg) + scheme_parts = parsed.scheme.split("+") + base_scheme = scheme_parts[0] if scheme_parts else "" + driver = scheme_parts[1] if len(scheme_parts) > 1 else None + + return { + "scheme": base_scheme, + "driver": driver, + "user": parsed.username, + "password": parsed.password, + "host": parsed.hostname, + "port": str(parsed.port) if parsed.port else None, + "database": parsed.path.lstrip("/") if parsed.path else None, + "query": parsed.query, + "fragment": parsed.fragment, + } + + +def build_database_uri( + scheme: str = "postgresql", + driver: Optional[str] = None, + user: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[str] = None, + database: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, +) -> str: + """ + Build a database URI from components. + + Args: + scheme: Base scheme (e.g., "postgresql") + driver: Driver name (e.g., "asyncpg", "pg8000") + user: Username + password: Password + host: Hostname + port: Port number + database: Database name + query: Query string + fragment: Fragment + + Returns: + Complete database URI + """ + # Combine scheme and driver + full_scheme = f"{scheme}+{driver}" if driver else scheme + + # Build netloc (user:password@host:port) + netloc_parts = [] + if user: + if password: + netloc_parts.append(f"{user}:{password}") + else: + netloc_parts.append(user) + + if host: + if port: + netloc_parts.append(f"{host}:{port}") + else: + netloc_parts.append(host) + + netloc = "@".join(netloc_parts) if netloc_parts else "" + + # Build path + path = f"/{database}" if database else "" + + # Build the URI + return urlunparse((full_scheme, netloc, path, "", query or "", fragment or "")) + + +def convert_to_async_uri(uri: str) -> str: + """ + Convert a database URI to use the asyncpg driver for async operations. + + Args: + uri: Original database URI + + Returns: + URI with asyncpg driver and ssl parameter adjustments + """ + components = parse_database_uri(uri) + + # Convert to asyncpg driver + components["driver"] = "asyncpg" + + # Build the new URI + new_uri = build_database_uri(**components) + + # Replace sslmode= with ssl= for asyncpg compatibility + new_uri = new_uri.replace("sslmode=", "ssl=") + + return new_uri + + +def convert_to_sync_uri(uri: str) -> str: + """ + Convert a database URI to use the pg8000 driver for sync operations (alembic). + + Args: + uri: Original database URI + + Returns: + URI with pg8000 driver and sslmode parameter adjustments + """ + components = parse_database_uri(uri) + + # Convert to pg8000 driver + components["driver"] = "pg8000" + + # Build the new URI + new_uri = build_database_uri(**components) + + # Replace ssl= with sslmode= for pg8000 compatibility + new_uri = new_uri.replace("ssl=", "sslmode=") + + return new_uri + + +def get_database_uri_for_context(uri: str, context: str = "async") -> str: + """ + Get the appropriate database URI for a specific context. + + Args: + uri: Original database URI + context: Context type ("async" for asyncpg, "sync" for pg8000, "alembic" for pg8000) + + Returns: + URI formatted for the specified context + """ + if context in ["async"]: + return convert_to_async_uri(uri) + elif context in ["sync", "alembic"]: + return convert_to_sync_uri(uri) + else: + raise ValueError(f"Unknown context: {context}. Must be 'async', 'sync', or 'alembic'") diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 9a500fe2..afb95488 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -325,6 +325,7 @@ class AnthropicClient(LLMClientBase): data["system"] = self._add_cache_control_to_system_message(system_content) data["messages"] = PydanticMessage.to_anthropic_dicts_from_list( messages=messages[1:], + current_model=llm_config.model, inner_thoughts_xml_tag=inner_thoughts_xml_tag, put_inner_thoughts_in_kwargs=put_kwargs, # if react, use native content + strip heartbeats diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index de7c4d81..4598cb7d 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -311,6 +311,7 @@ class GoogleVertexClient(LLMClientBase): contents = self.add_dummy_model_messages( PydanticMessage.to_google_dicts_from_list( messages, + current_model=llm_config.model, put_inner_thoughts_in_kwargs=False if agent_type == AgentType.letta_v1_agent else True, native_content=True if agent_type == AgentType.letta_v1_agent else False, ), diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 72d5f056..a834ab90 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -27,6 +27,7 @@ from letta.orm.prompt import Prompt from letta.orm.provider import Provider from letta.orm.provider_trace import ProviderTrace from letta.orm.run import Run +from letta.orm.run_metrics import RunMetrics from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.sources_agents import SourcesAgents diff --git a/letta/orm/run_metrics.py b/letta/orm/run_metrics.py new file mode 100644 index 00000000..71e9aa84 --- /dev/null +++ b/letta/orm/run_metrics.py @@ -0,0 +1,82 @@ +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Optional + +from sqlalchemy import BigInteger, ForeignKey, Integer, String +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Mapped, Session, mapped_column, relationship + +from letta.orm.mixins import AgentMixin, OrganizationMixin, ProjectMixin, TemplateMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics +from letta.schemas.user import User +from letta.settings import DatabaseChoice, settings + +if TYPE_CHECKING: + from letta.orm.agent import Agent + from letta.orm.run import Run + from letta.orm.step import Step + + +class RunMetrics(SqlalchemyBase, ProjectMixin, AgentMixin, OrganizationMixin, TemplateMixin): + """Tracks performance metrics for agent steps.""" + + __tablename__ = "run_metrics" + __pydantic_model__ = PydanticRunMetrics + + id: Mapped[str] = mapped_column( + ForeignKey("runs.id", ondelete="CASCADE"), + primary_key=True, + doc="The unique identifier of the run this metric belongs to (also serves as PK)", + ) + run_start_ns: Mapped[Optional[int]] = mapped_column( + BigInteger, + nullable=True, + doc="The timestamp of the start of the run in nanoseconds", + ) + run_ns: Mapped[Optional[int]] = mapped_column( + BigInteger, + nullable=True, + doc="Total time for the run in nanoseconds", + ) + num_steps: Mapped[Optional[int]] = mapped_column( + Integer, + nullable=True, + doc="The number of steps in the run", + ) + run: Mapped[Optional["Run"]] = relationship("Run", foreign_keys=[id]) + agent: Mapped[Optional["Agent"]] = relationship("Agent") + + def create( + self, + db_session: Session, + actor: Optional[User] = None, + no_commit: bool = False, + ) -> "RunMetrics": + """Override create to handle SQLite timestamp issues""" + # For SQLite, explicitly set timestamps as server_default may not work + if settings.database_engine == DatabaseChoice.SQLITE: + now = datetime.now(timezone.utc) + if not self.created_at: + self.created_at = now + if not self.updated_at: + self.updated_at = now + + return super().create(db_session, actor=actor, no_commit=no_commit) + + async def create_async( + self, + db_session: AsyncSession, + actor: Optional[User] = None, + no_commit: bool = False, + no_refresh: bool = False, + ) -> "RunMetrics": + """Override create_async to handle SQLite timestamp issues""" + # For SQLite, explicitly set timestamps as server_default may not work + if settings.database_engine == DatabaseChoice.SQLITE: + now = datetime.now(timezone.utc) + if not self.created_at: + self.created_at = now + if not self.updated_at: + self.updated_at = now + + return await super().create_async(db_session, actor=actor, no_commit=no_commit, no_refresh=no_refresh) diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 66b82590..31173aa7 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -965,7 +965,13 @@ class Message(BaseMessage): } elif self.role == "assistant" or self.role == "approval": - assert self.tool_calls is not None or text_content is not None, vars(self) + try: + assert self.tool_calls is not None or text_content is not None, vars(self) + except AssertionError as e: + # relax check if this message only contains reasoning content + if self.content is not None and len(self.content) > 0 and isinstance(self.content[0], ReasoningContent): + return None + raise e # if native content, then put it directly inside the content if native_content: @@ -1040,6 +1046,7 @@ class Message(BaseMessage): put_inner_thoughts_in_kwargs: bool = False, use_developer_message: bool = False, ) -> List[dict]: + messages = Message.filter_messages_for_llm_api(messages) result = [ m.to_openai_dict( max_tool_id_length=max_tool_id_length, @@ -1149,6 +1156,7 @@ class Message(BaseMessage): messages: List[Message], max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN, ) -> List[dict]: + messages = Message.filter_messages_for_llm_api(messages) result = [] for message in messages: result.extend(message.to_openai_responses_dicts(max_tool_id_length=max_tool_id_length)) @@ -1156,6 +1164,7 @@ class Message(BaseMessage): def to_anthropic_dict( self, + current_model: str, inner_thoughts_xml_tag="thinking", put_inner_thoughts_in_kwargs: bool = False, # if true, then treat the content field as AssistantMessage @@ -1242,20 +1251,22 @@ class Message(BaseMessage): for content_part in self.content: # TextContent, ImageContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent if isinstance(content_part, ReasoningContent): - content.append( - { - "type": "thinking", - "thinking": content_part.reasoning, - "signature": content_part.signature, - } - ) + if current_model == self.model: + content.append( + { + "type": "thinking", + "thinking": content_part.reasoning, + "signature": content_part.signature, + } + ) elif isinstance(content_part, RedactedReasoningContent): - content.append( - { - "type": "redacted_thinking", - "data": content_part.data, - } - ) + if current_model == self.model: + content.append( + { + "type": "redacted_thinking", + "data": content_part.data, + } + ) elif isinstance(content_part, TextContent): content.append( { @@ -1272,20 +1283,22 @@ class Message(BaseMessage): if self.content is not None and len(self.content) >= 1: for content_part in self.content: if isinstance(content_part, ReasoningContent): - content.append( - { - "type": "thinking", - "thinking": content_part.reasoning, - "signature": content_part.signature, - } - ) + if current_model == self.model: + content.append( + { + "type": "thinking", + "thinking": content_part.reasoning, + "signature": content_part.signature, + } + ) if isinstance(content_part, RedactedReasoningContent): - content.append( - { - "type": "redacted_thinking", - "data": content_part.data, - } - ) + if current_model == self.model: + content.append( + { + "type": "redacted_thinking", + "data": content_part.data, + } + ) if isinstance(content_part, TextContent): content.append( { @@ -1349,14 +1362,17 @@ class Message(BaseMessage): @staticmethod def to_anthropic_dicts_from_list( messages: List[Message], + current_model: str, inner_thoughts_xml_tag: str = "thinking", put_inner_thoughts_in_kwargs: bool = False, # if true, then treat the content field as AssistantMessage native_content: bool = False, strip_request_heartbeat: bool = False, ) -> List[dict]: + messages = Message.filter_messages_for_llm_api(messages) result = [ m.to_anthropic_dict( + current_model=current_model, inner_thoughts_xml_tag=inner_thoughts_xml_tag, put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, native_content=native_content, @@ -1369,6 +1385,7 @@ class Message(BaseMessage): def to_google_dict( self, + current_model: str, put_inner_thoughts_in_kwargs: bool = True, # if true, then treat the content field as AssistantMessage native_content: bool = False, @@ -1484,11 +1501,12 @@ class Message(BaseMessage): for content in self.content: if isinstance(content, TextContent): native_part = {"text": content.text} - if content.signature: + if content.signature and current_model == self.model: native_part["thought_signature"] = content.signature native_google_content_parts.append(native_part) elif isinstance(content, ReasoningContent): - native_google_content_parts.append({"text": content.reasoning, "thought": True}) + if current_model == self.model: + native_google_content_parts.append({"text": content.reasoning, "thought": True}) elif isinstance(content, ToolCallContent): native_part = { "function_call": { @@ -1496,7 +1514,7 @@ class Message(BaseMessage): "args": content.input, }, } - if content.signature: + if content.signature and current_model == self.model: native_part["thought_signature"] = content.signature native_google_content_parts.append(native_part) else: @@ -1554,11 +1572,14 @@ class Message(BaseMessage): @staticmethod def to_google_dicts_from_list( messages: List[Message], + current_model: str, put_inner_thoughts_in_kwargs: bool = True, native_content: bool = False, ): + messages = Message.filter_messages_for_llm_api(messages) result = [ m.to_google_dict( + current_model=current_model, put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, native_content=native_content, ) @@ -1567,6 +1588,45 @@ class Message(BaseMessage): result = [m for m in result if m is not None] return result + def is_approval_request(self) -> bool: + return self.role == "approval" and self.tool_calls is not None and len(self.tool_calls) > 0 + + def is_approval_response(self) -> bool: + return self.role == "approval" and self.tool_calls is None and self.approve is not None + + def is_summarization_message(self) -> bool: + return ( + self.role == "user" + and self.content is not None + and len(self.content) == 1 + and isinstance(self.content[0], TextContent) + and "system_alert" in self.content[0].text + ) + + @staticmethod + def filter_messages_for_llm_api( + messages: List[Message], + ) -> List[Message]: + messages = [m for m in messages if m is not None] + if len(messages) == 0: + return [] + # Add special handling for legacy bug where summarization triggers in the middle of hitl + messages_to_filter = [] + for i in range(len(messages) - 1): + first_message_is_approval = messages[i].is_approval_request() + second_message_is_summary = messages[i + 1].is_summarization_message() + third_message_is_optional_approval = i + 2 >= len(messages) or messages[i + 2].is_approval_response() + if first_message_is_approval and second_message_is_summary and third_message_is_optional_approval: + messages_to_filter.append(messages[i]) + for idx in reversed(messages_to_filter): # reverse to avoid index shift + messages.remove(idx) + + # Filter last message if it is a lone approval request without a response - this only occurs for token counting + if messages[-1].role == "approval" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0: + messages.remove(messages[-1]) + + return messages + @staticmethod def generate_otid_from_id(message_id: str, index: int) -> str: """ diff --git a/letta/schemas/run_metrics.py b/letta/schemas/run_metrics.py new file mode 100644 index 00000000..40971ab6 --- /dev/null +++ b/letta/schemas/run_metrics.py @@ -0,0 +1,21 @@ +from typing import Optional + +from pydantic import Field + +from letta.schemas.letta_base import LettaBase + + +class RunMetricsBase(LettaBase): + __id_prefix__ = "run" + + +class RunMetrics(RunMetricsBase): + id: str = Field(..., description="The id of the run this metric belongs to (matches runs.id).") + organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.") + agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") + project_id: Optional[str] = Field(None, description="The project that the run belongs to (cloud only).") + run_start_ns: Optional[int] = Field(None, description="The timestamp of the start of the run in nanoseconds.") + run_ns: Optional[int] = Field(None, description="Total time for the run in nanoseconds.") + num_steps: Optional[int] = Field(None, description="The number of steps in the run.") + template_id: Optional[str] = Field(None, description="The template ID that the run belongs to (cloud only).") + base_template_id: Optional[str] = Field(None, description="The base template ID that the run belongs to (cloud only).") diff --git a/letta/server/db.py b/letta/server/db.py index 4c1e3223..f5b44479 100644 --- a/letta/server/db.py +++ b/letta/server/db.py @@ -10,18 +10,11 @@ from sqlalchemy.ext.asyncio import ( create_async_engine, ) +from letta.database_utils import get_database_uri_for_context from letta.settings import settings -# Convert PostgreSQL URI to async format -pg_uri = settings.letta_pg_uri -if pg_uri.startswith("postgresql://"): - async_pg_uri = pg_uri.replace("postgresql://", "postgresql+asyncpg://") -else: - # Handle other URI formats (e.g., postgresql+pg8000://) - async_pg_uri = f"postgresql+asyncpg://{pg_uri.split('://', 1)[1]}" if "://" in pg_uri else pg_uri - -# Replace sslmode with ssl for asyncpg -async_pg_uri = async_pg_uri.replace("sslmode=", "ssl=") +# Convert PostgreSQL URI to async format using common utility +async_pg_uri = get_database_uri_for_context(settings.letta_pg_uri, "async") # Build engine configuration based on settings engine_args = { diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 083b09db..f79318e4 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -13,6 +13,7 @@ from letta.schemas.letta_request import RetrieveStreamRequest from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.run import Run +from letta.schemas.run_metrics import RunMetrics from letta.schemas.step import Step from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator @@ -224,6 +225,23 @@ async def retrieve_run_usage( raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found") +@router.get("/{run_id}/metrics", response_model=RunMetrics, operation_id="retrieve_metrics_for_run") +async def retrieve_metrics_for_run( + run_id: str, + headers: HeaderParams = Depends(get_headers), + server: "SyncServer" = Depends(get_letta_server), +): + """ + Get run metrics by run ID. + """ + try: + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() + return await runs_manager.get_run_metrics_async(run_id=run_id, actor=actor) + except NoResultFound: + raise HTTPException(status_code=404, detail="Run metrics not found") + + @router.get( "/{run_id}/steps", response_model=List[Step], @@ -247,18 +265,14 @@ async def list_run_steps( actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) runs_manager = RunManager() - try: - steps = await runs_manager.get_run_steps( - run_id=run_id, - actor=actor, - limit=limit, - before=before, - after=after, - ascending=(order == "asc"), - ) - return steps - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) + return await runs_manager.get_run_steps( + run_id=run_id, + actor=actor, + limit=limit, + before=before, + after=after, + ascending=(order == "asc"), + ) @router.delete("/{run_id}", response_model=Run, operation_id="delete_run") @@ -272,12 +286,7 @@ async def delete_run( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) runs_manager = RunManager() - - try: - run = await runs_manager.delete_run_by_id(run_id=run_id, actor=actor) - return run - except NoResultFound: - raise HTTPException(status_code=404, detail="Run not found") + return await runs_manager.delete_run_by_id(run_id=run_id, actor=actor) @router.post( diff --git a/letta/services/context_window_calculator/token_counter.py b/letta/services/context_window_calculator/token_counter.py index 96aecb0e..a940cf55 100644 --- a/letta/services/context_window_calculator/token_counter.py +++ b/letta/services/context_window_calculator/token_counter.py @@ -74,7 +74,7 @@ class AnthropicTokenCounter(TokenCounter): return await self.client.count_tokens(model=self.model, tools=tools) def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: - return Message.to_anthropic_dicts_from_list(messages) + return Message.to_anthropic_dicts_from_list(messages, current_model=self.model) class TiktokenCounter(TokenCounter): diff --git a/letta/services/helpers/run_manager_helper.py b/letta/services/helpers/run_manager_helper.py index b8d1fd29..90d6aab4 100644 --- a/letta/services/helpers/run_manager_helper.py +++ b/letta/services/helpers/run_manager_helper.py @@ -2,14 +2,10 @@ from datetime import datetime from typing import Optional from sqlalchemy import asc, desc, nulls_last, select -from letta.settings import DatabaseChoice, settings from letta.orm.run import Run as RunModel -from letta.settings import DatabaseChoice, settings -from sqlalchemy import asc, desc -from typing import Optional - from letta.services.helpers.agent_manager_helper import _cursor_filter +from letta.settings import DatabaseChoice, settings async def _apply_pagination_async( @@ -29,17 +25,11 @@ async def _apply_pagination_async( sort_nulls_last = False if after: - result = ( - await session.execute( - select(sort_column, RunModel.id).where(RunModel.id == after) - ) - ).first() + result = (await session.execute(select(sort_column, RunModel.id).where(RunModel.id == after))).first() if result: after_sort_value, after_id = result # SQLite does not support as granular timestamping, so we need to round the timestamp - if settings.database_engine is DatabaseChoice.SQLITE and isinstance( - after_sort_value, datetime - ): + if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime): after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S") query = query.where( _cursor_filter( @@ -53,17 +43,11 @@ async def _apply_pagination_async( ) if before: - result = ( - await session.execute( - select(sort_column, RunModel.id).where(RunModel.id == before) - ) - ).first() + result = (await session.execute(select(sort_column, RunModel.id).where(RunModel.id == before))).first() if result: before_sort_value, before_id = result # SQLite does not support as granular timestamping, so we need to round the timestamp - if settings.database_engine is DatabaseChoice.SQLITE and isinstance( - before_sort_value, datetime - ): + if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime): before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S") query = query.where( _cursor_filter( diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index cdd679d2..44838536 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -8,9 +8,11 @@ from sqlalchemy.orm import Session from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger +from letta.orm.agent import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel from letta.orm.run import Run as RunModel +from letta.orm.run_metrics import RunMetrics as RunMetricsModel from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step as StepModel from letta.otel.tracing import log_event, trace_method @@ -21,6 +23,7 @@ from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message as PydanticMessage from letta.schemas.run import Run as PydanticRun, RunUpdate +from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics from letta.schemas.step import Step as PydanticStep from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User as PydanticUser @@ -62,6 +65,23 @@ class RunManager: run = RunModel(**run_data) run.organization_id = organization_id run = await run.create_async(session, actor=actor, no_commit=True, no_refresh=True) + + # Create run metrics with start timestamp + import time + + # Get the project_id from the agent + agent = await session.get(AgentModel, agent_id) + project_id = agent.project_id if agent else None + + metrics = RunMetricsModel( + id=run.id, + organization_id=organization_id, + agent_id=agent_id, + project_id=project_id, + run_start_ns=int(time.time() * 1e9), # Current time in nanoseconds + num_steps=0, # Initialize to 0 + ) + await metrics.create_async(session) await session.commit() return run.to_pydantic() @@ -178,6 +198,21 @@ class RunManager: await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) final_metadata = run.metadata_ pydantic_run = run.to_pydantic() + + await session.commit() + + # update run metrics table + num_steps = len(await self.step_manager.list_steps_async(run_id=run_id, actor=actor)) + async with db_registry.async_session() as session: + metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor) + # Calculate runtime if run is completing + if is_terminal_update and metrics.run_start_ns: + import time + + current_ns = int(time.time() * 1e9) + metrics.run_ns = current_ns - metrics.run_start_ns + metrics.num_steps = num_steps + await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) await session.commit() # Dispatch callback outside of database session if needed @@ -299,3 +334,31 @@ class RunManager: raise NoResultFound(f"Run with id {run_id} not found") pydantic_run = run.to_pydantic() return pydantic_run.request_config + + @enforce_types + async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics: + """Get metrics for a run.""" + async with db_registry.async_session() as session: + metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor) + return metrics.to_pydantic() + + @enforce_types + async def get_run_steps( + self, + run_id: str, + actor: PydanticUser, + limit: Optional[int] = 100, + before: Optional[str] = None, + after: Optional[str] = None, + ascending: bool = False, + ) -> List[PydanticStep]: + """Get steps for a run.""" + async with db_registry.async_session() as session: + run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION) + if not run: + raise NoResultFound(f"Run with id {run_id} not found") + + steps = await self.step_manager.list_steps_async( + actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc" + ) + return steps diff --git a/letta/system.py b/letta/system.py index f76c836e..ec3cda2c 100644 --- a/letta/system.py +++ b/letta/system.py @@ -248,7 +248,11 @@ def unpack_message(packed_message: str) -> str: warnings.warn(f"Was unable to find 'message' field in packed message object: '{packed_message}'") return packed_message else: - message_type = message_json["type"] + try: + message_type = message_json["type"] + except: + return packed_message + if message_type != "user_message": warnings.warn(f"Expected type to be 'user_message', but was '{message_type}', so not unpacking: '{packed_message}'") return packed_message diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 7a50685b..5a96f6e7 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -254,18 +254,33 @@ def test_send_message_after_turning_off_requires_approval( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 6 or len(messages) == 8 or len(messages) == 9 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "tool_call_message" - assert messages[3].message_type == "tool_return_message" - if len(messages) == 8: - assert messages[4].message_type == "reasoning_message" - assert messages[5].message_type == "assistant_message" - elif len(messages) == 9: - assert messages[4].message_type == "reasoning_message" - assert messages[5].message_type == "tool_call_message" - assert messages[6].message_type == "tool_return_message" + assert 6 <= len(messages) <= 9 + idx = 0 + + assert messages[idx].message_type == "reasoning_message" + idx += 1 + + try: + assert messages[idx].message_type == "assistant_message" + idx += 1 + except: + pass + + assert messages[idx].message_type == "tool_call_message" + idx += 1 + assert messages[idx].message_type == "tool_return_message" + idx += 1 + + assert messages[idx].message_type == "reasoning_message" + idx += 1 + try: + assert messages[idx].message_type == "assistant_message" + idx += 1 + except: + assert messages[idx].message_type == "tool_call_message" + idx += 1 + assert messages[idx].message_type == "tool_return_message" + idx += 1 # ------------------------------ diff --git a/tests/managers/test_run_manager.py b/tests/managers/test_run_manager.py index 3994735c..f64f5ee1 100644 --- a/tests/managers/test_run_manager.py +++ b/tests/managers/test_run_manager.py @@ -1074,6 +1074,243 @@ async def test_get_run_request_config_nonexistent_run(server: SyncServer, defaul await server.run_manager.get_run_request_config("nonexistent_run", actor=default_user) +# ====================================================================================================================== +# RunManager Tests - Run Metrics +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_run_metrics_creation(server: SyncServer, sarah_agent, default_user): + """Test that run metrics are created when a run is created.""" + # Create a run + run_data = PydanticRun( + metadata={"type": "test_metrics"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Get the run metrics + metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user) + + # Assertions + assert metrics is not None + assert metrics.id == created_run.id + assert metrics.agent_id == sarah_agent.id + assert metrics.organization_id == default_user.organization_id + # project_id may be None or set from the agent + assert metrics.run_start_ns is not None + assert metrics.run_start_ns > 0 + assert metrics.run_ns is None # Should be None until run completes + assert metrics.num_steps is not None + assert metrics.num_steps == 0 # Should be 0 initially + + +@pytest.mark.asyncio +async def test_run_metrics_timestamp_tracking(server: SyncServer, sarah_agent, default_user): + """Test that run_start_ns is properly tracked.""" + import time + + # Record time before creation + before_ns = int(time.time() * 1e9) + + # Create a run + run_data = PydanticRun( + metadata={"type": "test_timestamp"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Record time after creation + after_ns = int(time.time() * 1e9) + + # Get the run metrics + metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user) + + # Verify timestamp is within expected range + assert metrics.run_start_ns is not None + assert before_ns <= metrics.run_start_ns <= after_ns, f"Expected {before_ns} <= {metrics.run_start_ns} <= {after_ns}" + + +@pytest.mark.asyncio +async def test_run_metrics_duration_calculation(server: SyncServer, sarah_agent, default_user): + """Test that run duration (run_ns) is calculated when run completes.""" + import asyncio + + # Create a run + run_data = PydanticRun( + metadata={"type": "test_duration"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Get initial metrics + initial_metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user) + assert initial_metrics.run_ns is None # Should be None initially + assert initial_metrics.run_start_ns is not None + + # Wait a bit to ensure there's measurable duration + await asyncio.sleep(0.1) # Wait 100ms + + # Update the run to completed + updated_run = await server.run_manager.update_run_by_id_async( + created_run.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user + ) + + # Get updated metrics + final_metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user) + + # Assertions + assert final_metrics.run_ns is not None + assert final_metrics.run_ns > 0 + # Duration should be at least 100ms (100_000_000 nanoseconds) + assert final_metrics.run_ns >= 100_000_000, f"Expected run_ns >= 100_000_000, got {final_metrics.run_ns}" + # Duration should be reasonable (less than 10 seconds) + assert final_metrics.run_ns < 10_000_000_000, f"Expected run_ns < 10_000_000_000, got {final_metrics.run_ns}" + + +@pytest.mark.asyncio +async def test_run_metrics_num_steps_tracking(server: SyncServer, sarah_agent, default_user): + """Test that num_steps is properly tracked in run metrics.""" + # Create a run + run_data = PydanticRun( + metadata={"type": "test_num_steps"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Initial metrics should have 0 steps + initial_metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user) + assert initial_metrics.num_steps == 0 + + # Add some steps + for i in range(3): + await server.step_manager.log_step_async( + agent_id=sarah_agent.id, + provider_name="openai", + provider_category="base", + model="gpt-4o-mini", + model_endpoint="https://api.openai.com/v1", + context_window_limit=8192, + usage=UsageStatistics( + completion_tokens=100 + i * 10, + prompt_tokens=50 + i * 5, + total_tokens=150 + i * 15, + ), + run_id=created_run.id, + actor=default_user, + project_id=sarah_agent.project_id, + ) + + # Update the run to trigger metrics update + await server.run_manager.update_run_by_id_async( + created_run.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user + ) + + # Get updated metrics + final_metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user) + + # Verify num_steps was updated + assert final_metrics.num_steps == 3 + + +@pytest.mark.asyncio +async def test_run_metrics_not_found(server: SyncServer, default_user): + """Test getting metrics for non-existent run.""" + with pytest.raises(NoResultFound): + await server.run_manager.get_run_metrics_async(run_id="nonexistent_run", actor=default_user) + + +@pytest.mark.asyncio +async def test_run_metrics_partial_update(server: SyncServer, sarah_agent, default_user): + """Test that non-terminal updates don't calculate run_ns.""" + # Create a run + run_data = PydanticRun( + metadata={"type": "test_partial"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Add a step + await server.step_manager.log_step_async( + agent_id=sarah_agent.id, + provider_name="openai", + provider_category="base", + model="gpt-4o-mini", + model_endpoint="https://api.openai.com/v1", + context_window_limit=8192, + usage=UsageStatistics( + completion_tokens=100, + prompt_tokens=50, + total_tokens=150, + ), + run_id=created_run.id, + actor=default_user, + project_id=sarah_agent.project_id, + ) + + # Update to running (non-terminal) + await server.run_manager.update_run_by_id_async(created_run.id, RunUpdate(status=RunStatus.running), actor=default_user) + + # Get metrics + metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user) + + # Verify run_ns is still None (not calculated for non-terminal updates) + assert metrics.run_ns is None + # But num_steps should be updated + assert metrics.num_steps == 1 + + +@pytest.mark.asyncio +async def test_run_metrics_integration_with_run_steps(server: SyncServer, sarah_agent, default_user): + """Test integration between run metrics and run steps.""" + # Create a run + run_data = PydanticRun( + metadata={"type": "test_integration"}, + agent_id=sarah_agent.id, + ) + created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user) + + # Add multiple steps + step_ids = [] + for i in range(5): + step = await server.step_manager.log_step_async( + agent_id=sarah_agent.id, + provider_name="openai", + provider_category="base", + model="gpt-4o-mini", + model_endpoint="https://api.openai.com/v1", + context_window_limit=8192, + usage=UsageStatistics( + completion_tokens=100, + prompt_tokens=50, + total_tokens=150, + ), + run_id=created_run.id, + actor=default_user, + project_id=sarah_agent.project_id, + ) + step_ids.append(step.id) + + # Get run steps + run_steps = await server.run_manager.get_run_steps(run_id=created_run.id, actor=default_user) + + # Verify steps are returned correctly + assert len(run_steps) == 5 + assert all(step.run_id == created_run.id for step in run_steps) + + # Update run to completed + await server.run_manager.update_run_by_id_async( + created_run.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user + ) + + # Get final metrics + metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user) + + # Verify metrics reflect the steps + assert metrics.num_steps == 5 + assert metrics.run_ns is not None + + # TODO: add back once metrics are added # @pytest.mark.asyncio