feat: add schema/db for new steps table (#669)
This commit is contained in:
@@ -0,0 +1,116 @@
|
||||
"""Repurpose JobUsageStatistics for new Steps table
|
||||
|
||||
Revision ID: 416b9d2db10b
|
||||
Revises: 25fc99e97839
|
||||
Create Date: 2025-01-17 11:27:42.115755
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "416b9d2db10b"
|
||||
down_revision: Union[str, None] = "25fc99e97839"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Repurpose the ``job_usage_statistics`` table as the new ``steps`` table.

    Renames the table, converts the integer primary key to a prefixed UUID
    string (``step-<uuid>``), relaxes the job FK so steps can exist without a
    job, adds the new per-step metadata columns/FKs, and links ``messages``
    to their originating step.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Rename the table
    op.rename_table("job_usage_statistics", "steps")

    # Drop the NOT NULL on job_id and remove the old FK (re-added below with
    # the new table's naming convention and ondelete behavior).
    op.alter_column("steps", "job_id", nullable=True)
    op.drop_constraint("fk_job_usage_statistics_job_id", "steps", type_="foreignkey")

    # Change id field from int to string. The old integer ids are discarded;
    # every existing row gets a freshly generated "step-<uuid>" id.
    # NOTE: gen_random_uuid() is PostgreSQL-specific (pgcrypto / PG13+).
    op.execute("ALTER TABLE steps RENAME COLUMN id TO old_id")
    op.add_column("steps", sa.Column("id", sa.String(), nullable=True))
    op.execute("""UPDATE steps SET id = 'step-' || gen_random_uuid()::text""")
    op.drop_column("steps", "old_id")  # dropping the PK column also drops the old PK constraint
    op.alter_column("steps", "id", nullable=False)
    op.create_primary_key("pk_steps_id", "steps", ["id"])

    # Add new columns (all nullable so existing rows remain valid)
    op.add_column("steps", sa.Column("origin", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("organization_id", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("provider_id", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("provider_name", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("model", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("context_window_limit", sa.Integer(), nullable=True))
    op.add_column(
        "steps",
        sa.Column("completion_tokens_details", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
    )
    op.add_column(
        "steps",
        sa.Column("tags", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
    )
    op.add_column("steps", sa.Column("tid", sa.String(), nullable=True))

    # Add new foreign key constraint for provider_id
    # FIX: the constraint names were previously swapped relative to the
    # columns they cover (fk_steps_organization_id covered provider_id and
    # vice versa). downgrade() drops both names, so this stays reversible.
    op.create_foreign_key("fk_steps_provider_id", "steps", "providers", ["provider_id"], ["id"], ondelete="RESTRICT")

    # Add new foreign key constraint for organization_id
    op.create_foreign_key("fk_steps_organization_id", "steps", "organizations", ["organization_id"], ["id"], ondelete="RESTRICT")

    # Add new foreign key constraint for job_id
    op.create_foreign_key("fk_steps_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="SET NULL")

    # Drop old step_id and step_count columns which aren't in the new model
    op.drop_column("steps", "step_id")
    op.drop_column("steps", "step_count")

    # Add step_id to messages table
    op.add_column("messages", sa.Column("step_id", sa.String(), nullable=True))
    op.create_foreign_key("fk_messages_step_id", "messages", "steps", ["step_id"], ["id"], ondelete="SET NULL")
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the ``steps`` table back to ``job_usage_statistics``.

    Inverse of :func:`upgrade`. NOTE(review): this downgrade is lossy —
    rows whose ``job_id`` is NULL are deleted, and the string ids are
    hashed back into integers, so original integer ids are not restored.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Remove step_id from messages first to avoid foreign key conflicts
    op.drop_constraint("fk_messages_step_id", "messages", type_="foreignkey")
    op.drop_column("messages", "step_id")

    # Restore old step_count and step_id column
    op.add_column("steps", sa.Column("step_count", sa.Integer(), nullable=True))
    op.add_column("steps", sa.Column("step_id", sa.String(), nullable=True))

    # Drop new columns and constraints
    op.drop_constraint("fk_steps_provider_id", "steps", type_="foreignkey")
    op.drop_constraint("fk_steps_organization_id", "steps", type_="foreignkey")
    op.drop_constraint("fk_steps_job_id", "steps", type_="foreignkey")

    op.drop_column("steps", "tid")
    op.drop_column("steps", "tags")
    op.drop_column("steps", "completion_tokens_details")
    op.drop_column("steps", "context_window_limit")
    op.drop_column("steps", "model")
    op.drop_column("steps", "provider_name")
    op.drop_column("steps", "provider_id")
    op.drop_column("steps", "organization_id")
    op.drop_column("steps", "origin")

    # Add constraints back
    # WARNING: destructive — steps not attached to a job cannot satisfy the
    # restored NOT NULL constraint, so they are deleted.
    op.execute("DELETE FROM steps WHERE job_id IS NULL")
    op.alter_column("steps", "job_id", nullable=False)
    op.create_foreign_key("fk_job_usage_statistics_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="CASCADE")

    # Change id field from string back to int
    # The "step-" prefix is stripped and the remainder hashed (hashtext is
    # PostgreSQL-specific); ABS + CAST keeps it in integer range. Collisions
    # are theoretically possible but the column is re-made the primary key.
    op.add_column("steps", sa.Column("old_id", sa.Integer(), nullable=True))
    op.execute("""UPDATE steps SET old_id = CAST(ABS(hashtext(REPLACE(id, 'step-', '')::text)) AS integer)""")
    op.drop_column("steps", "id")  # dropping the PK column also drops pk_steps_id
    op.execute("ALTER TABLE steps RENAME COLUMN old_id TO id")
    op.alter_column("steps", "id", nullable=False)
    op.create_primary_key("pk_steps_id", "steps", ["id"])

    # Rename the table
    op.rename_table("steps", "job_usage_statistics")
    # ### end Alembic commands ###
|
||||
@@ -49,6 +49,8 @@ from letta.services.helpers.agent_manager_helper import check_supports_structure
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.services.step_manager import StepManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
|
||||
@@ -130,8 +132,10 @@ class Agent(BaseAgent):
|
||||
# Create the persistence manager object based on the AgentState info
|
||||
self.message_manager = MessageManager()
|
||||
self.passage_manager = PassageManager()
|
||||
self.provider_manager = ProviderManager()
|
||||
self.agent_manager = AgentManager()
|
||||
self.job_manager = JobManager()
|
||||
self.step_manager = StepManager()
|
||||
|
||||
# State needed for heartbeat pausing
|
||||
|
||||
@@ -764,6 +768,24 @@ class Agent(BaseAgent):
|
||||
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
|
||||
)
|
||||
|
||||
# Log step - this must happen before messages are persisted
|
||||
step = self.step_manager.log_step(
|
||||
actor=self.user,
|
||||
provider_name=self.agent_state.llm_config.model_endpoint_type,
|
||||
model=self.agent_state.llm_config.model,
|
||||
context_window_limit=self.agent_state.llm_config.context_window,
|
||||
usage=response.usage,
|
||||
# TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
|
||||
provider_id=(
|
||||
self.provider_manager.get_anthropic_override_provider_id()
|
||||
if self.agent_state.llm_config.model_endpoint_type == "anthropic"
|
||||
else None
|
||||
),
|
||||
job_id=job_id,
|
||||
)
|
||||
for message in all_new_messages:
|
||||
message.step_id = step.id
|
||||
|
||||
# Persisting into Messages
|
||||
self.agent_state = self.agent_manager.append_to_in_context_messages(
|
||||
all_new_messages, agent_id=self.agent_state.id, actor=self.user
|
||||
|
||||
@@ -14,6 +14,7 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.settings import model_settings
|
||||
from letta.utils import get_utc_time, smart_urljoin
|
||||
|
||||
@@ -39,9 +40,6 @@ MODEL_LIST = [
|
||||
|
||||
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
||||
|
||||
if model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
|
||||
|
||||
def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
|
||||
for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
|
||||
@@ -397,6 +395,12 @@ def anthropic_chat_completions_request(
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
) -> ChatCompletionResponse:
|
||||
"""https://docs.anthropic.com/claude/docs/tool-use"""
|
||||
anthropic_client = None
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
|
||||
response = anthropic_client.beta.messages.create(
|
||||
**data,
|
||||
|
||||
@@ -6,7 +6,6 @@ from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.job_usage_statistics import JobUsageStatistics
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
||||
@@ -14,6 +13,7 @@ from letta.orm.provider import Provider
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
from letta.orm.step import Step
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.tools_agents import ToolsAgents
|
||||
from letta.orm.user import User
|
||||
|
||||
@@ -13,8 +13,8 @@ from letta.schemas.letta_request import LettaRequestConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.job_usage_statistics import JobUsageStatistics
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.step import Step
|
||||
from letta.orm.user import User
|
||||
|
||||
|
||||
@@ -41,9 +41,7 @@ class Job(SqlalchemyBase, UserMixin):
|
||||
# relationships
|
||||
user: Mapped["User"] = relationship("User", back_populates="jobs")
|
||||
job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")
|
||||
usage_statistics: Mapped[list["JobUsageStatistics"]] = relationship(
|
||||
"JobUsageStatistics", back_populates="job", cascade="all, delete-orphan"
|
||||
)
|
||||
steps: Mapped[List["Step"]] = relationship("Step", back_populates="job", cascade="save-update")
|
||||
|
||||
@property
|
||||
def messages(self) -> List["Message"]:
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.job import Job
|
||||
|
||||
|
||||
class JobUsageStatistics(SqlalchemyBase):
    """Tracks usage statistics for jobs, with future support for per-step tracking."""

    __tablename__ = "job_usage_statistics"

    # Auto-incrementing surrogate key for each usage snapshot.
    id: Mapped[int] = mapped_column(primary_key=True, doc="Unique identifier for the usage statistics entry")
    # Required owning job; rows are removed when the job is deleted (CASCADE).
    job_id: Mapped[str] = mapped_column(
        ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the job these statistics belong to"
    )
    # Free-form step identifier — no FK; per-step tracking was planned, not yet wired up.
    step_id: Mapped[Optional[str]] = mapped_column(
        nullable=True, doc="ID of the specific step within the job (for future per-step tracking)"
    )
    # Token counters all default to 0 so partially-populated rows stay summable.
    completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
    prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
    total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
    step_count: Mapped[int] = mapped_column(default=0, doc="Number of steps taken by the agent")

    # Relationship back to the job
    job: Mapped["Job"] = relationship("Job", back_populates="usage_statistics")
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import ToolCallColumn
|
||||
@@ -24,10 +24,14 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
|
||||
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
|
||||
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
|
||||
step_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("steps.id", ondelete="SET NULL"), nullable=True, doc="ID of the step that this message belongs to"
|
||||
)
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
||||
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin")
|
||||
|
||||
# Job relationship
|
||||
job_message: Mapped[Optional["JobMessage"]] = relationship(
|
||||
|
||||
54
letta/orm/step.py
Normal file
54
letta/orm/step.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.provider import Provider
|
||||
|
||||
|
||||
class Step(SqlalchemyBase):
    """Tracks all metadata for a single agent step.

    Replaces the old JobUsageStatistics model: one row per agent step, with
    token usage, provider/model configuration, and an optional link back to
    the job run that triggered it.
    """

    __tablename__ = "steps"
    __pydantic_model__ = PydanticStep

    # Prefixed UUID string primary key, matching the migration's "step-<uuid>" format.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"step-{uuid.uuid4()}")
    origin: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The surface that this agent step was initiated from.")
    # RESTRICT: an organization with logged steps cannot be deleted out from under them.
    organization_id: Mapped[str] = mapped_column(
        ForeignKey("organizations.id", ondelete="RESTRICT"),
        nullable=True,
        doc="The unique identifier of the organization that this step ran for",
    )
    provider_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("providers.id", ondelete="RESTRICT"),
        nullable=True,
        doc="The unique identifier of the provider that was configured for this step",
    )
    # SET NULL: deleting a job orphans its steps rather than deleting them.
    job_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identifier of the job run that triggered this step"
    )
    # FIX: dropped the redundant `None` first positional argument previously
    # passed to mapped_column on the four columns below — it is a no-op and
    # omitting it is the idiomatic form.
    provider_name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The name of the provider used for this step.")
    model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The name of the model used for this step.")
    context_window_limit: Mapped[Optional[int]] = mapped_column(
        nullable=True, doc="The context window limit configured for this step."
    )
    completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
    prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
    total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
    completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
    tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
    tid: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Transaction ID that processed the step.")

    # Relationships (foreign keys)
    # NOTE(review): "Organization" and "Message" are referenced below but not
    # in this file's TYPE_CHECKING imports — SQLAlchemy resolves the string
    # names via its class registry at runtime, but type checkers will flag
    # the annotations; consider adding the imports.
    organization: Mapped[Optional["Organization"]] = relationship("Organization")
    provider: Mapped[Optional["Provider"]] = relationship("Provider")
    job: Mapped[Optional["Job"]] = relationship("Job", back_populates="steps")

    # Relationships (backrefs)
    # noload: messages are never fetched implicitly with a step.
    messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload")
|
||||
@@ -99,6 +99,7 @@ class Message(BaseMessage):
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
|
||||
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
||||
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
|
||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
|
||||
31
letta/schemas/step.py
Normal file
31
letta/schemas/step.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
class StepBase(LettaBase):
    """Base schema for step objects; fixes the "step-" id prefix used by subclasses."""

    __id_prefix__ = "step"
|
||||
|
||||
|
||||
class Step(StepBase):
    """Pydantic schema for a single agent step.

    Mirrors the ``steps`` ORM model: identity/ownership fields, the
    provider/model configuration used, and the token usage recorded for
    the step.
    """

    id: str = Field(..., description="The id of the step. Assigned by the database.")
    origin: Optional[str] = Field(None, description="The surface that this agent step was initiated from.")
    organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the step.")
    provider_id: Optional[str] = Field(None, description="The unique identifier of the provider that was configured for this step")
    job_id: Optional[str] = Field(
        None, description="The unique identifier of the job that this step belongs to. Only included for async calls."
    )
    provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
    model: Optional[str] = Field(None, description="The name of the model used for this step.")
    context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.")
    completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.")
    prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
    total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
    completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")

    # FIX: use default_factory instead of a mutable literal default — the
    # idiomatic pydantic form (Field([]) is copied by pydantic, but
    # default_factory makes the per-instance default explicit).
    tags: List[str] = Field(default_factory=list, description="Metadata tags.")
    tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
    messages: List[Message] = Field(default_factory=list, description="The messages generated during this step.")
|
||||
@@ -595,9 +595,6 @@ async def process_message_background(
|
||||
)
|
||||
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
|
||||
|
||||
# Add job usage statistics
|
||||
server.job_manager.add_job_usage(job_id=job_id, usage=result.usage, actor=actor)
|
||||
|
||||
except Exception as e:
|
||||
# Update job status to failed
|
||||
job_update = JobUpdate(
|
||||
|
||||
@@ -72,6 +72,7 @@ from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.step_manager import StepManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
@@ -293,6 +294,7 @@ class SyncServer(Server):
|
||||
self.job_manager = JobManager()
|
||||
self.agent_manager = AgentManager()
|
||||
self.provider_manager = ProviderManager()
|
||||
self.step_manager = StepManager()
|
||||
|
||||
# Managers that interface with parallelism
|
||||
self.per_agent_lock_manager = PerAgentLockManager()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from functools import reduce
|
||||
from operator import add
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -7,9 +9,9 @@ from letta.orm.enums import JobType
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.job import Job as JobModel
|
||||
from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.job_usage_statistics import JobUsageStatistics
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate
|
||||
@@ -193,12 +195,7 @@ class JobManager:
|
||||
self._verify_job_access(session, job_id, actor)
|
||||
|
||||
# Get the latest usage statistics for the job
|
||||
latest_stats = (
|
||||
session.query(JobUsageStatistics)
|
||||
.filter(JobUsageStatistics.job_id == job_id)
|
||||
.order_by(JobUsageStatistics.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
latest_stats = session.query(Step).filter(Step.job_id == job_id).order_by(Step.created_at.desc()).all()
|
||||
|
||||
if not latest_stats:
|
||||
return LettaUsageStatistics(
|
||||
@@ -209,10 +206,10 @@ class JobManager:
|
||||
)
|
||||
|
||||
return LettaUsageStatistics(
|
||||
completion_tokens=latest_stats.completion_tokens,
|
||||
prompt_tokens=latest_stats.prompt_tokens,
|
||||
total_tokens=latest_stats.total_tokens,
|
||||
step_count=latest_stats.step_count,
|
||||
completion_tokens=reduce(add, (step.completion_tokens or 0 for step in latest_stats), 0),
|
||||
prompt_tokens=reduce(add, (step.prompt_tokens or 0 for step in latest_stats), 0),
|
||||
total_tokens=reduce(add, (step.total_tokens or 0 for step in latest_stats), 0),
|
||||
step_count=len(latest_stats),
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
@@ -239,8 +236,9 @@ class JobManager:
|
||||
# First verify job exists and user has access
|
||||
self._verify_job_access(session, job_id, actor, access=["write"])
|
||||
|
||||
# Create new usage statistics entry
|
||||
usage_stats = JobUsageStatistics(
|
||||
# Manually log step with usage data
|
||||
# TODO(@caren): log step under the hood and remove this
|
||||
usage_stats = Step(
|
||||
job_id=job_id,
|
||||
completion_tokens=usage.completion_tokens,
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
|
||||
@@ -48,9 +48,13 @@ class ProviderManager:
|
||||
def delete_provider_by_id(self, provider_id: str):
|
||||
"""Delete a provider."""
|
||||
with self.session_maker() as session:
|
||||
# Delete from provider table
|
||||
provider = ProviderModel.read(db_session=session, identifier=provider_id)
|
||||
provider.hard_delete(session)
|
||||
# Clear api key field
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id)
|
||||
existing_provider.api_key = None
|
||||
existing_provider.update(session)
|
||||
|
||||
# Soft delete in provider table
|
||||
existing_provider.delete(session)
|
||||
|
||||
session.commit()
|
||||
|
||||
@@ -62,9 +66,17 @@ class ProviderManager:
|
||||
return [provider.to_pydantic() for provider in results]
|
||||
|
||||
@enforce_types
|
||||
def get_anthropic_key_override(self) -> Optional[str]:
|
||||
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
|
||||
providers = self.list_providers(limit=1)
|
||||
if len(providers) == 1 and providers[0].name == "anthropic":
|
||||
return providers[0].api_key
|
||||
def get_anthropic_override_provider_id(self) -> Optional[str]:
|
||||
"""Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
|
||||
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
|
||||
if len(anthropic_provider) != 0:
|
||||
return anthropic_provider[0].id
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def get_anthropic_override_key(self) -> Optional[str]:
|
||||
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
|
||||
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
|
||||
if len(anthropic_provider) != 0:
|
||||
return anthropic_provider[0].api_key
|
||||
return None
|
||||
|
||||
87
letta/services/step_manager.py
Normal file
87
letta/services/step_manager.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.job import Job as JobModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class StepManager:
    """Service layer for creating and reading agent step records."""

    def __init__(self):
        # Imported lazily to avoid a circular import with the server module.
        from letta.server.server import db_context

        self.session_maker = db_context

    @enforce_types
    def log_step(
        self,
        actor: PydanticUser,
        provider_name: str,
        model: str,
        context_window_limit: int,
        usage: UsageStatistics,
        provider_id: Optional[str] = None,
        job_id: Optional[str] = None,
    ) -> PydanticStep:
        """Persist one step record (usage + provider/model config) and return it.

        When ``job_id`` is given, the actor must have write access to that job.
        """
        with self.session_maker() as session:
            if job_id:
                self._verify_job_access(session, job_id, actor, access=["write"])
            step_row = StepModel(
                origin=None,
                organization_id=actor.organization_id,
                provider_id=provider_id,
                provider_name=provider_name,
                model=model,
                context_window_limit=context_window_limit,
                completion_tokens=usage.completion_tokens,
                prompt_tokens=usage.prompt_tokens,
                total_tokens=usage.total_tokens,
                job_id=job_id,
                tags=[],
                tid=None,
            )
            step_row.create(session)
            return step_row.to_pydantic()

    @enforce_types
    def get_step(self, step_id: str) -> PydanticStep:
        """Fetch a single step by id; raises if it does not exist."""
        with self.session_maker() as session:
            return StepModel.read(db_session=session, identifier=step_id).to_pydantic()

    def _verify_job_access(
        self,
        session: Session,
        job_id: str,
        actor: PydanticUser,
        access: List[Literal["read", "write", "delete"]] = ["read"],
    ) -> JobModel:
        """
        Verify that a job exists and the user has the required access.

        Args:
            session: The database session
            job_id: The ID of the job to verify
            actor: The user making the request
            access: The access level(s) required on the job

        Returns:
            The job if it exists and the user has access

        Raises:
            NoResultFound: If the job does not exist or user does not have access
        """
        guarded_query = JobModel.apply_access_predicate(
            select(JobModel).where(JobModel.id == job_id), actor, access, AccessType.USER
        )
        job = session.execute(guarded_query).scalar_one_or_none()
        if job is None:
            raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
        return job
|
||||
@@ -26,6 +26,7 @@ from letta.orm import (
|
||||
Source,
|
||||
SourcePassage,
|
||||
SourcesAgents,
|
||||
Step,
|
||||
Tool,
|
||||
ToolsAgents,
|
||||
User,
|
||||
@@ -46,6 +47,7 @@ from letta.schemas.letta_request import LettaRequestConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
@@ -56,7 +58,6 @@ from letta.schemas.source import SourceUpdate
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.server.server import SyncServer
|
||||
@@ -100,6 +101,7 @@ def clear_tables(server: SyncServer):
|
||||
session.execute(delete(Tool)) # Clear all records from the Tool table
|
||||
session.execute(delete(Agent))
|
||||
session.execute(delete(User)) # Clear all records from the user table
|
||||
session.execute(delete(Step))
|
||||
session.execute(delete(Provider))
|
||||
session.execute(delete(Organization)) # Clear all records from the organization table
|
||||
session.commit() # Commit the deletion
|
||||
@@ -2835,17 +2837,19 @@ def test_get_run_messages_cursor(server: SyncServer, default_user: PydanticUser,
|
||||
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
|
||||
"""Test adding and retrieving job usage statistics."""
|
||||
job_manager = server.job_manager
|
||||
step_manager = server.step_manager
|
||||
|
||||
# Add usage statistics
|
||||
job_manager.add_job_usage(
|
||||
step_manager.log_step(
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
context_window_limit=8192,
|
||||
job_id=default_job.id,
|
||||
usage=LettaUsageStatistics(
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=100,
|
||||
prompt_tokens=50,
|
||||
total_tokens=150,
|
||||
step_count=5,
|
||||
),
|
||||
step_id="step_1",
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -2874,30 +2878,33 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
|
||||
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
|
||||
"""Test adding multiple usage statistics entries for a job."""
|
||||
job_manager = server.job_manager
|
||||
step_manager = server.step_manager
|
||||
|
||||
# Add first usage statistics entry
|
||||
job_manager.add_job_usage(
|
||||
step_manager.log_step(
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
context_window_limit=8192,
|
||||
job_id=default_job.id,
|
||||
usage=LettaUsageStatistics(
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=100,
|
||||
prompt_tokens=50,
|
||||
total_tokens=150,
|
||||
step_count=5,
|
||||
),
|
||||
step_id="step_1",
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add second usage statistics entry
|
||||
job_manager.add_job_usage(
|
||||
step_manager.log_step(
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
context_window_limit=8192,
|
||||
job_id=default_job.id,
|
||||
usage=LettaUsageStatistics(
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=200,
|
||||
prompt_tokens=100,
|
||||
total_tokens=300,
|
||||
step_count=10,
|
||||
),
|
||||
step_id="step_2",
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -2905,9 +2912,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
|
||||
usage_stats = job_manager.get_job_usage(job_id=default_job.id, actor=default_user)
|
||||
|
||||
# Verify we get the most recent statistics
|
||||
assert usage_stats.completion_tokens == 200
|
||||
assert usage_stats.prompt_tokens == 100
|
||||
assert usage_stats.total_tokens == 300
|
||||
assert usage_stats.completion_tokens == 300
|
||||
assert usage_stats.prompt_tokens == 150
|
||||
assert usage_stats.total_tokens == 450
|
||||
assert usage_stats.step_count == 2
|
||||
|
||||
|
||||
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
||||
@@ -2920,18 +2928,19 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
||||
|
||||
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
|
||||
"""Test adding usage statistics for a nonexistent job."""
|
||||
job_manager = server.job_manager
|
||||
step_manager = server.step_manager
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
job_manager.add_job_usage(
|
||||
step_manager.log_step(
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
context_window_limit=8192,
|
||||
job_id="nonexistent_job",
|
||||
usage=LettaUsageStatistics(
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=100,
|
||||
prompt_tokens=50,
|
||||
total_tokens=150,
|
||||
step_count=5,
|
||||
),
|
||||
step_id="step_1",
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
import letta.utils as utils
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.orm import Provider, Step
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
||||
from letta.schemas.providers import Provider as PydanticProvider
|
||||
from letta.schemas.user import User
|
||||
|
||||
utils.DEBUG = True
|
||||
@@ -277,6 +281,10 @@ def org_id(server):
|
||||
yield org.id
|
||||
|
||||
# cleanup
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Step))
|
||||
session.execute(delete(Provider))
|
||||
session.commit()
|
||||
server.organization_manager.delete_organization_by_id(org.id)
|
||||
|
||||
|
||||
@@ -1098,3 +1106,72 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
|
||||
request.tool_ids = [b.id for b in base_tools[:-2]]
|
||||
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
|
||||
assert len(agent_state.tools) == len(base_tools) - 2
|
||||
|
||||
|
||||
def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
provider = server.provider_manager.create_provider(
|
||||
provider=PydanticProvider(
|
||||
name="anthropic",
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
memory_blocks=[], llm="anthropic/claude-3-opus-20240229", context_window_limit=200000, embedding="openai/text-embedding-ada-002"
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
|
||||
|
||||
usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
|
||||
assert usage, "Sending message failed"
|
||||
|
||||
get_messages_response = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id)
|
||||
assert len(get_messages_response) > 0, "Retrieving messages failed"
|
||||
|
||||
step_ids = set([msg.step_id for msg in get_messages_response])
|
||||
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
||||
for step_id in step_ids:
|
||||
step = server.step_manager.get_step(step_id=step_id)
|
||||
assert step, "Step was not logged correctly"
|
||||
assert step.provider_id == provider.id
|
||||
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||
assert step.model == agent.llm_config.model
|
||||
assert step.context_window_limit == agent.llm_config.context_window
|
||||
completion_tokens += int(step.completion_tokens)
|
||||
prompt_tokens += int(step.prompt_tokens)
|
||||
total_tokens += int(step.total_tokens)
|
||||
|
||||
assert completion_tokens == usage.completion_tokens
|
||||
assert prompt_tokens == usage.prompt_tokens
|
||||
assert total_tokens == usage.total_tokens
|
||||
|
||||
server.provider_manager.delete_provider_by_id(provider.id)
|
||||
|
||||
existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
|
||||
|
||||
usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
|
||||
assert usage, "Sending message failed"
|
||||
|
||||
get_messages_response = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id)
|
||||
assert len(get_messages_response) > 0, "Retrieving messages failed"
|
||||
|
||||
step_ids = set([msg.step_id for msg in get_messages_response])
|
||||
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
||||
for step_id in step_ids:
|
||||
step = server.step_manager.get_step(step_id=step_id)
|
||||
assert step, "Step was not logged correctly"
|
||||
assert step.provider_id == None
|
||||
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||
assert step.model == agent.llm_config.model
|
||||
assert step.context_window_limit == agent.llm_config.context_window
|
||||
completion_tokens += int(step.completion_tokens)
|
||||
prompt_tokens += int(step.prompt_tokens)
|
||||
total_tokens += int(step.total_tokens)
|
||||
|
||||
assert completion_tokens == usage.completion_tokens
|
||||
assert prompt_tokens == usage.prompt_tokens
|
||||
assert total_tokens == usage.total_tokens
|
||||
|
||||
Reference in New Issue
Block a user