feat: add schema/db for new steps table (#669)

This commit is contained in:
cthomas
2025-01-18 12:20:10 -08:00
committed by GitHub
parent 13dfe4adbd
commit ef6fce8e0f
17 changed files with 466 additions and 84 deletions

View File

@@ -0,0 +1,116 @@
"""Repurpose JobUsageStatistics for new Steps table
Revision ID: 416b9d2db10b
Revises: 25fc99e97839
Create Date: 2025-01-17 11:27:42.115755
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "416b9d2db10b"
down_revision: Union[str, None] = "25fc99e97839"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Repurpose job_usage_statistics as the new steps table.

    Renames the table, converts the integer PK to string 'step-<uuid>' ids,
    adds the new step metadata columns and foreign keys, and links messages
    to steps. Postgres-specific (gen_random_uuid, postgresql.JSON).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Rename the table
    op.rename_table("job_usage_statistics", "steps")

    # job_id becomes optional (a step may now exist without a job); drop the
    # old FK so it can be recreated below under the new naming scheme.
    op.alter_column("steps", "job_id", nullable=True)
    op.drop_constraint("fk_job_usage_statistics_job_id", "steps", type_="foreignkey")

    # Change id field from int to string: park the integer PK under old_id,
    # backfill string ids of the form 'step-<uuid>', then promote the new
    # column to primary key.
    op.execute("ALTER TABLE steps RENAME COLUMN id TO old_id")
    op.add_column("steps", sa.Column("id", sa.String(), nullable=True))
    op.execute("""UPDATE steps SET id = 'step-' || gen_random_uuid()::text""")
    op.drop_column("steps", "old_id")
    op.alter_column("steps", "id", nullable=False)
    op.create_primary_key("pk_steps_id", "steps", ["id"])

    # Add new columns (all nullable: pre-existing rows cannot backfill them)
    op.add_column("steps", sa.Column("origin", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("organization_id", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("provider_id", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("provider_name", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("model", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("context_window_limit", sa.Integer(), nullable=True))
    op.add_column(
        "steps",
        sa.Column("completion_tokens_details", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
    )
    op.add_column(
        "steps",
        sa.Column("tags", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
    )
    op.add_column("steps", sa.Column("tid", sa.String(), nullable=True))

    # Foreign keys for the new columns.
    # (Fixed: the provider/organization constraint names were previously
    # swapped — each name now matches the column it constrains.)
    op.create_foreign_key("fk_steps_provider_id", "steps", "providers", ["provider_id"], ["id"], ondelete="RESTRICT")
    op.create_foreign_key("fk_steps_organization_id", "steps", "organizations", ["organization_id"], ["id"], ondelete="RESTRICT")
    op.create_foreign_key("fk_steps_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="SET NULL")

    # Drop old step_id and step_count columns which aren't in the new model
    op.drop_column("steps", "step_id")
    op.drop_column("steps", "step_count")

    # Add step_id to messages table so each message records its originating step
    op.add_column("messages", sa.Column("step_id", sa.String(), nullable=True))
    op.create_foreign_key("fk_messages_step_id", "messages", "steps", ["step_id"], ["id"], ondelete="SET NULL")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert the steps table back to job_usage_statistics.

    Inverse of upgrade(), applied in reverse dependency order. This downgrade
    is lossy: step rows without a job are deleted and string ids are replaced
    by hash-derived integers.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Remove step_id from messages first to avoid foreign key conflicts
    op.drop_constraint("fk_messages_step_id", "messages", type_="foreignkey")
    op.drop_column("messages", "step_id")

    # Restore old step_count and step_id column
    op.add_column("steps", sa.Column("step_count", sa.Integer(), nullable=True))
    op.add_column("steps", sa.Column("step_id", sa.String(), nullable=True))

    # Drop new columns and constraints
    op.drop_constraint("fk_steps_provider_id", "steps", type_="foreignkey")
    op.drop_constraint("fk_steps_organization_id", "steps", type_="foreignkey")
    op.drop_constraint("fk_steps_job_id", "steps", type_="foreignkey")
    op.drop_column("steps", "tid")
    op.drop_column("steps", "tags")
    op.drop_column("steps", "completion_tokens_details")
    op.drop_column("steps", "context_window_limit")
    op.drop_column("steps", "model")
    op.drop_column("steps", "provider_name")
    op.drop_column("steps", "provider_id")
    op.drop_column("steps", "organization_id")
    op.drop_column("steps", "origin")

    # Add constraints back
    # NOTE: rows without a job are deleted so job_id can be made NOT NULL again.
    op.execute("DELETE FROM steps WHERE job_id IS NULL")
    op.alter_column("steps", "job_id", nullable=False)
    op.create_foreign_key("fk_job_usage_statistics_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="CASCADE")

    # Change id field from string back to int
    # NOTE(review): hashtext() of the uuid is not the original integer id and
    # ABS(hash) may collide — presumably acceptable because the old integer ids
    # carried no external meaning; confirm before relying on this downgrade.
    op.add_column("steps", sa.Column("old_id", sa.Integer(), nullable=True))
    op.execute("""UPDATE steps SET old_id = CAST(ABS(hashtext(REPLACE(id, 'step-', '')::text)) AS integer)""")
    op.drop_column("steps", "id")
    op.execute("ALTER TABLE steps RENAME COLUMN old_id TO id")
    op.alter_column("steps", "id", nullable=False)
    op.create_primary_key("pk_steps_id", "steps", ["id"])

    # Rename the table
    op.rename_table("steps", "job_usage_statistics")
    # ### end Alembic commands ###

View File

@@ -49,6 +49,8 @@ from letta.services.helpers.agent_manager_helper import check_supports_structure
from letta.services.job_manager import JobManager from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.step_manager import StepManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.streaming_interface import StreamingRefreshCLIInterface from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
@@ -130,8 +132,10 @@ class Agent(BaseAgent):
# Create the persistence manager object based on the AgentState info # Create the persistence manager object based on the AgentState info
self.message_manager = MessageManager() self.message_manager = MessageManager()
self.passage_manager = PassageManager() self.passage_manager = PassageManager()
self.provider_manager = ProviderManager()
self.agent_manager = AgentManager() self.agent_manager = AgentManager()
self.job_manager = JobManager() self.job_manager = JobManager()
self.step_manager = StepManager()
# State needed for heartbeat pausing # State needed for heartbeat pausing
@@ -764,6 +768,24 @@ class Agent(BaseAgent):
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}" f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
) )
# Log step - this must happen before messages are persisted
step = self.step_manager.log_step(
actor=self.user,
provider_name=self.agent_state.llm_config.model_endpoint_type,
model=self.agent_state.llm_config.model,
context_window_limit=self.agent_state.llm_config.context_window,
usage=response.usage,
# TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
provider_id=(
self.provider_manager.get_anthropic_override_provider_id()
if self.agent_state.llm_config.model_endpoint_type == "anthropic"
else None
),
job_id=job_id,
)
for message in all_new_messages:
message.step_id = step.id
# Persisting into Messages # Persisting into Messages
self.agent_state = self.agent_manager.append_to_in_context_messages( self.agent_state = self.agent_manager.append_to_in_context_messages(
all_new_messages, agent_id=self.agent_state.id, actor=self.user all_new_messages, agent_id=self.agent_state.id, actor=self.user

View File

@@ -14,6 +14,7 @@ from letta.schemas.openai.chat_completion_response import (
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
) )
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings from letta.settings import model_settings
from letta.utils import get_utc_time, smart_urljoin from letta.utils import get_utc_time, smart_urljoin
@@ -39,9 +40,6 @@ MODEL_LIST = [
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence." DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
if model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int: def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
for model_dict in anthropic_get_model_list(url=url, api_key=api_key): for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
@@ -397,6 +395,12 @@ def anthropic_chat_completions_request(
betas: List[str] = ["tools-2024-04-04"], betas: List[str] = ["tools-2024-04-04"],
) -> ChatCompletionResponse: ) -> ChatCompletionResponse:
"""https://docs.anthropic.com/claude/docs/tool-use""" """https://docs.anthropic.com/claude/docs/tool-use"""
anthropic_client = None
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
data = _prepare_anthropic_request(data, inner_thoughts_xml_tag) data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
response = anthropic_client.beta.messages.create( response = anthropic_client.beta.messages.create(
**data, **data,

View File

@@ -6,7 +6,6 @@ from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata from letta.orm.file import FileMetadata
from letta.orm.job import Job from letta.orm.job import Job
from letta.orm.job_messages import JobMessage from letta.orm.job_messages import JobMessage
from letta.orm.job_usage_statistics import JobUsageStatistics
from letta.orm.message import Message from letta.orm.message import Message
from letta.orm.organization import Organization from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
@@ -14,6 +13,7 @@ from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents from letta.orm.sources_agents import SourcesAgents
from letta.orm.step import Step
from letta.orm.tool import Tool from letta.orm.tool import Tool
from letta.orm.tools_agents import ToolsAgents from letta.orm.tools_agents import ToolsAgents
from letta.orm.user import User from letta.orm.user import User

View File

@@ -13,8 +13,8 @@ from letta.schemas.letta_request import LettaRequestConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from letta.orm.job_messages import JobMessage from letta.orm.job_messages import JobMessage
from letta.orm.job_usage_statistics import JobUsageStatistics
from letta.orm.message import Message from letta.orm.message import Message
from letta.orm.step import Step
from letta.orm.user import User from letta.orm.user import User
@@ -41,9 +41,7 @@ class Job(SqlalchemyBase, UserMixin):
# relationships # relationships
user: Mapped["User"] = relationship("User", back_populates="jobs") user: Mapped["User"] = relationship("User", back_populates="jobs")
job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan") job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")
usage_statistics: Mapped[list["JobUsageStatistics"]] = relationship( steps: Mapped[List["Step"]] = relationship("Step", back_populates="job", cascade="save-update")
"JobUsageStatistics", back_populates="job", cascade="all, delete-orphan"
)
@property @property
def messages(self) -> List["Message"]: def messages(self) -> List["Message"]:

View File

@@ -1,30 +0,0 @@
from typing import TYPE_CHECKING, Optional
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
if TYPE_CHECKING:
from letta.orm.job import Job
class JobUsageStatistics(SqlalchemyBase):
    """Tracks usage statistics for jobs, with future support for per-step tracking."""

    __tablename__ = "job_usage_statistics"

    # Integer surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True, doc="Unique identifier for the usage statistics entry")
    # Owning job; rows are removed with the job (ondelete CASCADE).
    job_id: Mapped[str] = mapped_column(
        ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the job these statistics belong to"
    )
    step_id: Mapped[Optional[str]] = mapped_column(
        nullable=True, doc="ID of the specific step within the job (for future per-step tracking)"
    )
    # Token counters default to 0 so partially-populated rows remain summable.
    completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
    prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
    total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
    step_count: Mapped[int] = mapped_column(default=0, doc="Number of steps taken by the agent")

    # Relationship back to the job
    job: Mapped["Job"] = relationship("Job", back_populates="usage_statistics")

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from sqlalchemy import Index from sqlalchemy import ForeignKey, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import ToolCallColumn from letta.orm.custom_columns import ToolCallColumn
@@ -24,10 +24,14 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios") name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information") tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call") tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
step_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("steps.id", ondelete="SET NULL"), nullable=True, doc="ID of the step that this message belongs to"
)
# Relationships # Relationships
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin") agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin") organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin")
# Job relationship # Job relationship
job_message: Mapped[Optional["JobMessage"]] = relationship( job_message: Mapped[Optional["JobMessage"]] = relationship(

54
letta/orm/step.py Normal file
View File

@@ -0,0 +1,54 @@
import uuid
from typing import TYPE_CHECKING, Dict, List, Optional
from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.step import Step as PydanticStep
if TYPE_CHECKING:
from letta.orm.job import Job
from letta.orm.provider import Provider
class Step(SqlalchemyBase):
    """Tracks all metadata for agent step."""

    __tablename__ = "steps"
    __pydantic_model__ = PydanticStep

    # String primary key of the form "step-<uuid4>", generated client-side.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"step-{uuid.uuid4()}")
    origin: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The surface that this agent step was initiated from.")
    # Fixed: annotation was Mapped[str] despite nullable=True.
    organization_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("organizations.id", ondelete="RESTRICT"),
        nullable=True,
        doc="The unique identifier of the organization that this step ran for",
    )
    provider_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("providers.id", ondelete="RESTRICT"),
        nullable=True,
        doc="The unique identifier of the provider that was configured for this step",
    )
    # Fixed typo in doc: "identified" -> "identifier".
    job_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identifier of the job run that triggered this step"
    )
    # Removed the spurious `None` positional argument previously passed to
    # mapped_column for the four columns below; types are inferred from the
    # Mapped[...] annotations.
    provider_name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The name of the provider used for this step.")
    model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The name of the model used for this step.")
    context_window_limit: Mapped[Optional[int]] = mapped_column(
        nullable=True, doc="The context window limit configured for this step."
    )
    completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
    prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
    total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
    completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
    tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
    tid: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Transaction ID that processed the step.")

    # Relationships (foreign keys)
    # NOTE(review): "Organization" and "Message" are resolved through
    # SQLAlchemy's string-based class registry at mapper configuration time;
    # they are not imported under TYPE_CHECKING in this module — confirm
    # static type checkers are satisfied.
    organization: Mapped[Optional["Organization"]] = relationship("Organization")
    provider: Mapped[Optional["Provider"]] = relationship("Provider")
    job: Mapped[Optional["Job"]] = relationship("Job", back_populates="steps")

    # Relationships (backrefs)
    messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload")

View File

@@ -99,6 +99,7 @@ class Message(BaseMessage):
name: Optional[str] = Field(None, description="The name of the participant.") name: Optional[str] = Field(None, description="The name of the participant.")
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.") tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.") tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects # This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")

31
letta/schemas/step.py Normal file
View File

@@ -0,0 +1,31 @@
from typing import Dict, List, Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
from letta.schemas.message import Message
class StepBase(LettaBase):
    """Shared base for step schemas; step ids carry the "step-" prefix."""

    # Prefix used by LettaBase machinery when generating/validating ids.
    __id_prefix__ = "step"
class Step(StepBase):
    """Read model for a persisted agent step: provider/model configuration,
    token usage, and the messages generated during the step."""

    id: str = Field(..., description="The id of the step. Assigned by the database.")
    origin: Optional[str] = Field(None, description="The surface that this agent step was initiated from.")
    organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the step.")
    provider_id: Optional[str] = Field(None, description="The unique identifier of the provider that was configured for this step")
    job_id: Optional[str] = Field(
        None, description="The unique identifier of the job that this step belongs to. Only included for async calls."
    )
    provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
    model: Optional[str] = Field(None, description="The name of the model used for this step.")
    context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.")
    completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.")
    prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
    total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
    completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
    # default_factory gives each instance its own list rather than reusing a
    # shared literal default (idiomatic for mutable defaults).
    tags: List[str] = Field(default_factory=list, description="Metadata tags.")
    tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
    messages: List[Message] = Field(default_factory=list, description="The messages generated during this step.")

View File

@@ -595,9 +595,6 @@ async def process_message_background(
) )
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor) server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
# Add job usage statistics
server.job_manager.add_job_usage(job_id=job_id, usage=result.usage, actor=actor)
except Exception as e: except Exception as e:
# Update job status to failed # Update job status to failed
job_update = JobUpdate( job_update = JobUpdate(

View File

@@ -72,6 +72,7 @@ from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.provider_manager import ProviderManager from letta.services.provider_manager import ProviderManager
from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager from letta.services.source_manager import SourceManager
from letta.services.step_manager import StepManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager from letta.services.user_manager import UserManager
@@ -293,6 +294,7 @@ class SyncServer(Server):
self.job_manager = JobManager() self.job_manager = JobManager()
self.agent_manager = AgentManager() self.agent_manager = AgentManager()
self.provider_manager = ProviderManager() self.provider_manager = ProviderManager()
self.step_manager = StepManager()
# Managers that interface with parallelism # Managers that interface with parallelism
self.per_agent_lock_manager = PerAgentLockManager() self.per_agent_lock_manager = PerAgentLockManager()

View File

@@ -1,3 +1,5 @@
from functools import reduce
from operator import add
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
from sqlalchemy import select from sqlalchemy import select
@@ -7,9 +9,9 @@ from letta.orm.enums import JobType
from letta.orm.errors import NoResultFound from letta.orm.errors import NoResultFound
from letta.orm.job import Job as JobModel from letta.orm.job import Job as JobModel
from letta.orm.job_messages import JobMessage from letta.orm.job_messages import JobMessage
from letta.orm.job_usage_statistics import JobUsageStatistics
from letta.orm.message import Message as MessageModel from letta.orm.message import Message as MessageModel
from letta.orm.sqlalchemy_base import AccessType from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step
from letta.schemas.enums import JobStatus, MessageRole from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.job import Job as PydanticJob from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate from letta.schemas.job import JobUpdate
@@ -193,12 +195,7 @@ class JobManager:
self._verify_job_access(session, job_id, actor) self._verify_job_access(session, job_id, actor)
# Get the latest usage statistics for the job # Get the latest usage statistics for the job
latest_stats = ( latest_stats = session.query(Step).filter(Step.job_id == job_id).order_by(Step.created_at.desc()).all()
session.query(JobUsageStatistics)
.filter(JobUsageStatistics.job_id == job_id)
.order_by(JobUsageStatistics.created_at.desc())
.first()
)
if not latest_stats: if not latest_stats:
return LettaUsageStatistics( return LettaUsageStatistics(
@@ -209,10 +206,10 @@ class JobManager:
) )
return LettaUsageStatistics( return LettaUsageStatistics(
completion_tokens=latest_stats.completion_tokens, completion_tokens=reduce(add, (step.completion_tokens or 0 for step in latest_stats), 0),
prompt_tokens=latest_stats.prompt_tokens, prompt_tokens=reduce(add, (step.prompt_tokens or 0 for step in latest_stats), 0),
total_tokens=latest_stats.total_tokens, total_tokens=reduce(add, (step.total_tokens or 0 for step in latest_stats), 0),
step_count=latest_stats.step_count, step_count=len(latest_stats),
) )
@enforce_types @enforce_types
@@ -239,8 +236,9 @@ class JobManager:
# First verify job exists and user has access # First verify job exists and user has access
self._verify_job_access(session, job_id, actor, access=["write"]) self._verify_job_access(session, job_id, actor, access=["write"])
# Create new usage statistics entry # Manually log step with usage data
usage_stats = JobUsageStatistics( # TODO(@caren): log step under the hood and remove this
usage_stats = Step(
job_id=job_id, job_id=job_id,
completion_tokens=usage.completion_tokens, completion_tokens=usage.completion_tokens,
prompt_tokens=usage.prompt_tokens, prompt_tokens=usage.prompt_tokens,

View File

@@ -48,9 +48,13 @@ class ProviderManager:
def delete_provider_by_id(self, provider_id: str): def delete_provider_by_id(self, provider_id: str):
"""Delete a provider.""" """Delete a provider."""
with self.session_maker() as session: with self.session_maker() as session:
# Delete from provider table # Clear api key field
provider = ProviderModel.read(db_session=session, identifier=provider_id) existing_provider = ProviderModel.read(db_session=session, identifier=provider_id)
provider.hard_delete(session) existing_provider.api_key = None
existing_provider.update(session)
# Soft delete in provider table
existing_provider.delete(session)
session.commit() session.commit()
@@ -62,9 +66,17 @@ class ProviderManager:
return [provider.to_pydantic() for provider in results] return [provider.to_pydantic() for provider in results]
@enforce_types @enforce_types
def get_anthropic_key_override(self) -> Optional[str]: def get_anthropic_override_provider_id(self) -> Optional[str]:
"""Helper function to fetch custom anthropic key for v0 BYOK feature""" """Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
providers = self.list_providers(limit=1) anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
if len(providers) == 1 and providers[0].name == "anthropic": if len(anthropic_provider) != 0:
return providers[0].api_key return anthropic_provider[0].id
return None
@enforce_types
def get_anthropic_override_key(self) -> Optional[str]:
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
if len(anthropic_provider) != 0:
return anthropic_provider[0].api_key
return None return None

View File

@@ -0,0 +1,87 @@
from typing import List, Literal, Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
from letta.orm.errors import NoResultFound
from letta.orm.job import Job as JobModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class StepManager:
    """Manager for creating and reading Step records."""

    def __init__(self):
        # Imported lazily to avoid a circular import with the server module.
        from letta.server.server import db_context

        self.session_maker = db_context

    @enforce_types
    def log_step(
        self,
        actor: PydanticUser,
        provider_name: str,
        model: str,
        context_window_limit: int,
        usage: UsageStatistics,
        provider_id: Optional[str] = None,
        job_id: Optional[str] = None,
    ) -> PydanticStep:
        """Persist a new step row recording the config and token usage of one agent step.

        Args:
            actor: The user on whose behalf the step ran; supplies organization_id.
            provider_name: Name of the LLM provider used for this step.
            model: Name of the model used for this step.
            context_window_limit: Context window configured for this step.
            usage: Token usage reported for this step.
            provider_id: Optional id of the configured provider (v0 BYOK workaround).
            job_id: Optional id of the job run that triggered this step.

        Returns:
            The created step as a Pydantic model.

        Raises:
            NoResultFound: If job_id is given and the job does not exist or the
                actor lacks write access to it.
        """
        step_data = {
            "origin": None,
            "organization_id": actor.organization_id,
            "provider_id": provider_id,
            "provider_name": provider_name,
            "model": model,
            "context_window_limit": context_window_limit,
            "completion_tokens": usage.completion_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "total_tokens": usage.total_tokens,
            "job_id": job_id,
            "tags": [],
            "tid": None,
        }
        with self.session_maker() as session:
            if job_id:
                # Validate the job before writing so a bad job_id fails fast.
                self._verify_job_access(session, job_id, actor, access=["write"])
            new_step = StepModel(**step_data)
            new_step.create(session)
            return new_step.to_pydantic()

    @enforce_types
    def get_step(self, step_id: str) -> PydanticStep:
        """Fetch a single step by id.

        Raises:
            NoResultFound: If no step with the given id exists.
        """
        with self.session_maker() as session:
            step = StepModel.read(db_session=session, identifier=step_id)
            return step.to_pydantic()

    def _verify_job_access(
        self,
        session: Session,
        job_id: str,
        actor: PydanticUser,
        access: Optional[List[Literal["read", "write", "delete"]]] = None,
    ) -> JobModel:
        """
        Verify that a job exists and the user has the required access.

        Args:
            session: The database session
            job_id: The ID of the job to verify
            actor: The user making the request
            access: Required access levels; defaults to ["read"].

        Returns:
            The job if it exists and the user has access

        Raises:
            NoResultFound: If the job does not exist or user does not have access
        """
        # Fixed: avoid a mutable default argument; ["read"] is the effective default.
        access = ["read"] if access is None else access
        job_query = select(JobModel).where(JobModel.id == job_id)
        job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER)
        job = session.execute(job_query).scalar_one_or_none()
        if not job:
            raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
        return job

View File

@@ -26,6 +26,7 @@ from letta.orm import (
Source, Source,
SourcePassage, SourcePassage,
SourcesAgents, SourcesAgents,
Step,
Tool, Tool,
ToolsAgents, ToolsAgents,
User, User,
@@ -46,6 +47,7 @@ from letta.schemas.letta_request import LettaRequestConfig
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate from letta.schemas.message import MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.passage import Passage as PydanticPassage
@@ -56,7 +58,6 @@ from letta.schemas.source import SourceUpdate
from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolUpdate from letta.schemas.tool import ToolUpdate
from letta.schemas.tool_rule import InitToolRule from letta.schemas.tool_rule import InitToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate from letta.schemas.user import UserUpdate
from letta.server.server import SyncServer from letta.server.server import SyncServer
@@ -100,6 +101,7 @@ def clear_tables(server: SyncServer):
session.execute(delete(Tool)) # Clear all records from the Tool table session.execute(delete(Tool)) # Clear all records from the Tool table
session.execute(delete(Agent)) session.execute(delete(Agent))
session.execute(delete(User)) # Clear all records from the user table session.execute(delete(User)) # Clear all records from the user table
session.execute(delete(Step))
session.execute(delete(Provider)) session.execute(delete(Provider))
session.execute(delete(Organization)) # Clear all records from the organization table session.execute(delete(Organization)) # Clear all records from the organization table
session.commit() # Commit the deletion session.commit() # Commit the deletion
@@ -2835,17 +2837,19 @@ def test_get_run_messages_cursor(server: SyncServer, default_user: PydanticUser,
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user): def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
"""Test adding and retrieving job usage statistics.""" """Test adding and retrieving job usage statistics."""
job_manager = server.job_manager job_manager = server.job_manager
step_manager = server.step_manager
# Add usage statistics # Add usage statistics
job_manager.add_job_usage( step_manager.log_step(
provider_name="openai",
model="gpt-4",
context_window_limit=8192,
job_id=default_job.id, job_id=default_job.id,
usage=LettaUsageStatistics( usage=UsageStatistics(
completion_tokens=100, completion_tokens=100,
prompt_tokens=50, prompt_tokens=50,
total_tokens=150, total_tokens=150,
step_count=5,
), ),
step_id="step_1",
actor=default_user, actor=default_user,
) )
@@ -2874,30 +2878,33 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user): def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
"""Test adding multiple usage statistics entries for a job.""" """Test adding multiple usage statistics entries for a job."""
job_manager = server.job_manager job_manager = server.job_manager
step_manager = server.step_manager
# Add first usage statistics entry # Add first usage statistics entry
job_manager.add_job_usage( step_manager.log_step(
provider_name="openai",
model="gpt-4",
context_window_limit=8192,
job_id=default_job.id, job_id=default_job.id,
usage=LettaUsageStatistics( usage=UsageStatistics(
completion_tokens=100, completion_tokens=100,
prompt_tokens=50, prompt_tokens=50,
total_tokens=150, total_tokens=150,
step_count=5,
), ),
step_id="step_1",
actor=default_user, actor=default_user,
) )
# Add second usage statistics entry # Add second usage statistics entry
job_manager.add_job_usage( step_manager.log_step(
provider_name="openai",
model="gpt-4",
context_window_limit=8192,
job_id=default_job.id, job_id=default_job.id,
usage=LettaUsageStatistics( usage=UsageStatistics(
completion_tokens=200, completion_tokens=200,
prompt_tokens=100, prompt_tokens=100,
total_tokens=300, total_tokens=300,
step_count=10,
), ),
step_id="step_2",
actor=default_user, actor=default_user,
) )
@@ -2905,9 +2912,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
usage_stats = job_manager.get_job_usage(job_id=default_job.id, actor=default_user) usage_stats = job_manager.get_job_usage(job_id=default_job.id, actor=default_user)
# Verify we get the most recent statistics # Verify we get the most recent statistics
assert usage_stats.completion_tokens == 200 assert usage_stats.completion_tokens == 300
assert usage_stats.prompt_tokens == 100 assert usage_stats.prompt_tokens == 150
assert usage_stats.total_tokens == 300 assert usage_stats.total_tokens == 450
assert usage_stats.step_count == 2
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
@@ -2920,18 +2928,19 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user): def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
"""Test adding usage statistics for a nonexistent job.""" """Test adding usage statistics for a nonexistent job."""
job_manager = server.job_manager step_manager = server.step_manager
with pytest.raises(NoResultFound): with pytest.raises(NoResultFound):
job_manager.add_job_usage( step_manager.log_step(
provider_name="openai",
model="gpt-4",
context_window_limit=8192,
job_id="nonexistent_job", job_id="nonexistent_job",
usage=LettaUsageStatistics( usage=UsageStatistics(
completion_tokens=100, completion_tokens=100,
prompt_tokens=50, prompt_tokens=50,
total_tokens=150, total_tokens=150,
step_count=5,
), ),
step_id="step_1",
actor=default_user, actor=default_user,
) )

View File

@@ -1,15 +1,19 @@
import json import json
import os
import uuid import uuid
import warnings import warnings
from typing import List, Tuple from typing import List, Tuple
import pytest import pytest
from sqlalchemy import delete
import letta.utils as utils import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.orm import Provider, Step
from letta.schemas.block import CreateBlock from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.user import User from letta.schemas.user import User
utils.DEBUG = True utils.DEBUG = True
@@ -277,6 +281,10 @@ def org_id(server):
yield org.id yield org.id
# cleanup # cleanup
with server.organization_manager.session_maker() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()
server.organization_manager.delete_organization_by_id(org.id) server.organization_manager.delete_organization_by_id(org.id)
@@ -1098,3 +1106,72 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
request.tool_ids = [b.id for b in base_tools[:-2]] request.tool_ids = [b.id for b in base_tools[:-2]]
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor) agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
assert len(agent_state.tools) == len(base_tools) - 2 assert len(agent_state.tools) == len(base_tools) - 2
def test_messages_with_provider_override(server: SyncServer, user_id: str):
    """Steps logged for a message should record provider metadata.

    Sends a message while a BYO provider override is active and checks that
    every logged step records the provider id, provider name, model, and
    context window, and that per-step token counts sum to the usage reported
    to the caller. Then deletes the provider and repeats the check: steps
    should log ``provider_id`` as ``None`` while keeping the rest of the
    metadata intact.
    """
    actor = server.user_manager.get_user_or_default(user_id)

    # Create a BYO provider override for anthropic.
    provider = server.provider_manager.create_provider(
        provider=PydanticProvider(
            name="anthropic",
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        ),
        actor=actor,
    )
    agent = server.create_agent(
        request=CreateAgent(
            memory_blocks=[],
            llm="anthropic/claude-3-opus-20240229",
            context_window_limit=200000,
            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )

    def _send_message_and_verify_steps(expected_provider_id):
        """Send one message, then assert each new step's provider metadata
        matches *expected_provider_id* and token totals match usage."""
        # Capture the current tail so we only fetch the newly created messages.
        existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
        usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
        assert usage, "Sending message failed"

        get_messages_response = server.message_manager.list_messages_for_agent(
            agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id
        )
        assert len(get_messages_response) > 0, "Retrieving messages failed"

        step_ids = {msg.step_id for msg in get_messages_response}
        completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
        for step_id in step_ids:
            step = server.step_manager.get_step(step_id=step_id)
            assert step, "Step was not logged correctly"
            assert step.provider_id == expected_provider_id
            assert step.provider_name == agent.llm_config.model_endpoint_type
            assert step.model == agent.llm_config.model
            assert step.context_window_limit == agent.llm_config.context_window
            completion_tokens += int(step.completion_tokens)
            prompt_tokens += int(step.prompt_tokens)
            total_tokens += int(step.total_tokens)

        # Per-step token counts must sum to the usage reported to the caller.
        assert completion_tokens == usage.completion_tokens
        assert prompt_tokens == usage.prompt_tokens
        assert total_tokens == usage.total_tokens

    # With the override in place, steps should carry the provider's id.
    _send_message_and_verify_steps(expected_provider_id=provider.id)

    # After deleting the override, steps should log provider_id as None.
    server.provider_manager.delete_provider_by_id(provider.id)
    _send_message_and_verify_steps(expected_provider_id=None)