feat: add schema/db for new steps table (#669)

This commit is contained in:
cthomas
2025-01-18 12:20:10 -08:00
committed by GitHub
parent 8583c72cb5
commit 551cc0820c
17 changed files with 466 additions and 84 deletions

View File

@@ -0,0 +1,116 @@
"""Repurpose JobUsageStatistics for new Steps table
Revision ID: 416b9d2db10b
Revises: 25fc99e97839
Create Date: 2025-01-17 11:27:42.115755
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "416b9d2db10b"
down_revision: Union[str, None] = "25fc99e97839"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Repurpose the job_usage_statistics table as the new steps table.

    Renames the table, converts the integer primary key into a string
    "step-<uuid>" id, adds the new step-metadata columns and foreign keys,
    and links messages to steps via a new nullable step_id column.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Rename the table
    op.rename_table("job_usage_statistics", "steps")

    # Drop the NOT NULL constraint on job_id and the old foreign key constraint
    op.alter_column("steps", "job_id", nullable=True)
    op.drop_constraint("fk_job_usage_statistics_job_id", "steps", type_="foreignkey")

    # Change id field from int to string: stash the old id, backfill new
    # "step-<uuid>" ids (Postgres gen_random_uuid), then promote to primary key.
    op.execute("ALTER TABLE steps RENAME COLUMN id TO old_id")
    op.add_column("steps", sa.Column("id", sa.String(), nullable=True))
    op.execute("""UPDATE steps SET id = 'step-' || gen_random_uuid()::text""")
    op.drop_column("steps", "old_id")
    op.alter_column("steps", "id", nullable=False)
    op.create_primary_key("pk_steps_id", "steps", ["id"])

    # Add new columns
    op.add_column("steps", sa.Column("origin", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("organization_id", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("provider_id", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("provider_name", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("model", sa.String(), nullable=True))
    op.add_column("steps", sa.Column("context_window_limit", sa.Integer(), nullable=True))
    op.add_column(
        "steps",
        sa.Column("completion_tokens_details", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
    )
    op.add_column(
        "steps",
        sa.Column("tags", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
    )
    op.add_column("steps", sa.Column("tid", sa.String(), nullable=True))

    # Add new foreign key constraints.
    # FIX: the constraint names were swapped in the original revision —
    # "fk_steps_organization_id" was created on provider_id -> providers and
    # "fk_steps_provider_id" on organization_id -> organizations. The downgrade
    # drops both names regardless, so correcting them here stays reversible.
    op.create_foreign_key("fk_steps_provider_id", "steps", "providers", ["provider_id"], ["id"], ondelete="RESTRICT")
    op.create_foreign_key("fk_steps_organization_id", "steps", "organizations", ["organization_id"], ["id"], ondelete="RESTRICT")
    op.create_foreign_key("fk_steps_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="SET NULL")

    # Drop old step_id and step_count columns which aren't in the new model
    op.drop_column("steps", "step_id")
    op.drop_column("steps", "step_count")

    # Add step_id to messages table
    op.add_column("messages", sa.Column("step_id", sa.String(), nullable=True))
    op.create_foreign_key("fk_messages_step_id", "messages", "steps", ["step_id"], ["id"], ondelete="SET NULL")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert the steps table back to job_usage_statistics.

    NOTE(review): this downgrade is lossy — rows with NULL job_id are deleted,
    and the original integer ids cannot be recovered (they are re-derived by
    hashing the uuid portion of the string id). Confirm that is acceptable
    before running in production.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Remove step_id from messages first to avoid foreign key conflicts
    op.drop_constraint("fk_messages_step_id", "messages", type_="foreignkey")
    op.drop_column("messages", "step_id")
    # Restore old step_count and step_id columns
    op.add_column("steps", sa.Column("step_count", sa.Integer(), nullable=True))
    op.add_column("steps", sa.Column("step_id", sa.String(), nullable=True))
    # Drop new columns and constraints
    op.drop_constraint("fk_steps_provider_id", "steps", type_="foreignkey")
    op.drop_constraint("fk_steps_organization_id", "steps", type_="foreignkey")
    op.drop_constraint("fk_steps_job_id", "steps", type_="foreignkey")
    op.drop_column("steps", "tid")
    op.drop_column("steps", "tags")
    op.drop_column("steps", "completion_tokens_details")
    op.drop_column("steps", "context_window_limit")
    op.drop_column("steps", "model")
    op.drop_column("steps", "provider_name")
    op.drop_column("steps", "provider_id")
    op.drop_column("steps", "organization_id")
    op.drop_column("steps", "origin")
    # Add constraints back
    # Rows without a job cannot satisfy the restored NOT NULL constraint.
    op.execute("DELETE FROM steps WHERE job_id IS NULL")
    op.alter_column("steps", "job_id", nullable=False)
    op.create_foreign_key("fk_job_usage_statistics_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="CASCADE")
    # Change id field from string back to int
    op.add_column("steps", sa.Column("old_id", sa.Integer(), nullable=True))
    # Best-effort integer id: hash the uuid suffix of the string id (not the
    # original auto-increment value — see docstring note on lossiness).
    op.execute("""UPDATE steps SET old_id = CAST(ABS(hashtext(REPLACE(id, 'step-', '')::text)) AS integer)""")
    op.drop_column("steps", "id")
    op.execute("ALTER TABLE steps RENAME COLUMN old_id TO id")
    op.alter_column("steps", "id", nullable=False)
    op.create_primary_key("pk_steps_id", "steps", ["id"])
    # Rename the table
    op.rename_table("steps", "job_usage_statistics")
    # ### end Alembic commands ###

View File

@@ -49,6 +49,8 @@ from letta.services.helpers.agent_manager_helper import check_supports_structure
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.step_manager import StepManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
@@ -130,8 +132,10 @@ class Agent(BaseAgent):
# Create the persistence manager object based on the AgentState info
self.message_manager = MessageManager()
self.passage_manager = PassageManager()
self.provider_manager = ProviderManager()
self.agent_manager = AgentManager()
self.job_manager = JobManager()
self.step_manager = StepManager()
# State needed for heartbeat pausing
@@ -764,6 +768,24 @@ class Agent(BaseAgent):
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
)
# Log step - this must happen before messages are persisted
step = self.step_manager.log_step(
actor=self.user,
provider_name=self.agent_state.llm_config.model_endpoint_type,
model=self.agent_state.llm_config.model,
context_window_limit=self.agent_state.llm_config.context_window,
usage=response.usage,
# TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
provider_id=(
self.provider_manager.get_anthropic_override_provider_id()
if self.agent_state.llm_config.model_endpoint_type == "anthropic"
else None
),
job_id=job_id,
)
for message in all_new_messages:
message.step_id = step.id
# Persisting into Messages
self.agent_state = self.agent_manager.append_to_in_context_messages(
all_new_messages, agent_id=self.agent_state.id, actor=self.user

View File

@@ -14,6 +14,7 @@ from letta.schemas.openai.chat_completion_response import (
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
)
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings
from letta.utils import get_utc_time, smart_urljoin
@@ -39,9 +40,6 @@ MODEL_LIST = [
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
if model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
@@ -397,6 +395,12 @@ def anthropic_chat_completions_request(
betas: List[str] = ["tools-2024-04-04"],
) -> ChatCompletionResponse:
"""https://docs.anthropic.com/claude/docs/tool-use"""
anthropic_client = None
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
response = anthropic_client.beta.messages.create(
**data,

View File

@@ -6,7 +6,6 @@ from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.job import Job
from letta.orm.job_messages import JobMessage
from letta.orm.job_usage_statistics import JobUsageStatistics
from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
@@ -14,6 +13,7 @@ from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents
from letta.orm.step import Step
from letta.orm.tool import Tool
from letta.orm.tools_agents import ToolsAgents
from letta.orm.user import User

View File

@@ -13,8 +13,8 @@ from letta.schemas.letta_request import LettaRequestConfig
if TYPE_CHECKING:
from letta.orm.job_messages import JobMessage
from letta.orm.job_usage_statistics import JobUsageStatistics
from letta.orm.message import Message
from letta.orm.step import Step
from letta.orm.user import User
@@ -41,9 +41,7 @@ class Job(SqlalchemyBase, UserMixin):
# relationships
user: Mapped["User"] = relationship("User", back_populates="jobs")
job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")
usage_statistics: Mapped[list["JobUsageStatistics"]] = relationship(
"JobUsageStatistics", back_populates="job", cascade="all, delete-orphan"
)
steps: Mapped[List["Step"]] = relationship("Step", back_populates="job", cascade="save-update")
@property
def messages(self) -> List["Message"]:

View File

@@ -1,30 +0,0 @@
from typing import TYPE_CHECKING, Optional
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
if TYPE_CHECKING:
from letta.orm.job import Job
class JobUsageStatistics(SqlalchemyBase):
    """Tracks usage statistics for jobs, with future support for per-step tracking.

    NOTE(review): this legacy model is deleted in this commit; the ``steps``
    table (letta.orm.step.Step) supersedes it.
    """

    __tablename__ = "job_usage_statistics"

    # Auto-incrementing integer surrogate key (replaced by a string
    # "step-<uuid>" id in the successor Step model).
    id: Mapped[int] = mapped_column(primary_key=True, doc="Unique identifier for the usage statistics entry")
    # Owning job; rows are removed with the job (ON DELETE CASCADE).
    job_id: Mapped[str] = mapped_column(
        ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the job these statistics belong to"
    )
    step_id: Mapped[Optional[str]] = mapped_column(
        nullable=True, doc="ID of the specific step within the job (for future per-step tracking)"
    )
    # Token counters default to 0 so partially-populated rows stay summable.
    completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
    prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
    total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
    step_count: Mapped[int] = mapped_column(default=0, doc="Number of steps taken by the agent")

    # Relationship back to the job
    job: Mapped["Job"] = relationship("Job", back_populates="usage_statistics")

View File

@@ -1,6 +1,6 @@
from typing import Optional
from sqlalchemy import Index
from sqlalchemy import ForeignKey, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import ToolCallColumn
@@ -24,10 +24,14 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
step_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("steps.id", ondelete="SET NULL"), nullable=True, doc="ID of the step that this message belongs to"
)
# Relationships
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin")
# Job relationship
job_message: Mapped[Optional["JobMessage"]] = relationship(

54
letta/orm/step.py Normal file
View File

@@ -0,0 +1,54 @@
import uuid
from typing import TYPE_CHECKING, Dict, List, Optional
from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.step import Step as PydanticStep
if TYPE_CHECKING:
from letta.orm.job import Job
from letta.orm.provider import Provider
class Step(SqlalchemyBase):
    """Tracks all metadata for an agent step (provider/model info and token usage)."""

    __tablename__ = "steps"
    __pydantic_model__ = PydanticStep

    # Application-generated primary key of the form "step-<uuid>".
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"step-{uuid.uuid4()}")
    origin: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The surface that this agent step was initiated from.")
    # FIX: annotation was Mapped[str] while the column is nullable.
    organization_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("organizations.id", ondelete="RESTRICT"),
        nullable=True,
        doc="The unique identifier of the organization that this step ran for",
    )
    provider_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("providers.id", ondelete="RESTRICT"),
        nullable=True,
        doc="The unique identifier of the provider that was configured for this step",
    )
    job_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identifier of the job run that triggered this step"
    )
    # FIX: removed a stray positional ``None`` (interpreted as a Column name)
    # passed to mapped_column on the three columns below.
    provider_name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The name of the provider used for this step.")
    model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The name of the model used for this step.")
    context_window_limit: Mapped[Optional[int]] = mapped_column(
        nullable=True, doc="The context window limit configured for this step."
    )
    # Token counters default to 0 so sums over steps never hit NULL.
    completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
    prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
    total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
    completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
    tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
    tid: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Transaction ID that processed the step.")

    # Relationships (foreign keys)
    # NOTE(review): "Organization" and "Message" are not imported under
    # TYPE_CHECKING in this module — confirm the string forward references
    # resolve at mapper-configuration time.
    organization: Mapped[Optional["Organization"]] = relationship("Organization")
    provider: Mapped[Optional["Provider"]] = relationship("Provider")
    job: Mapped[Optional["Job"]] = relationship("Job", back_populates="steps")

    # Relationships (backrefs)
    messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload")

View File

@@ -99,6 +99,7 @@ class Message(BaseMessage):
name: Optional[str] = Field(None, description="The name of the participant.")
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")

31
letta/schemas/step.py Normal file
View File

@@ -0,0 +1,31 @@
from typing import Dict, List, Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
from letta.schemas.message import Message
class StepBase(LettaBase):
    """Common base for step schemas; sets the "step-" id prefix used by LettaBase."""

    __id_prefix__ = "step"
class Step(StepBase):
    """A single agent step: provider/model metadata plus token usage totals."""

    id: str = Field(..., description="The id of the step. Assigned by the database.")
    origin: Optional[str] = Field(None, description="The surface that this agent step was initiated from.")
    organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the step.")
    provider_id: Optional[str] = Field(None, description="The unique identifier of the provider that was configured for this step")
    job_id: Optional[str] = Field(
        None, description="The unique identifier of the job that this step belongs to. Only included for async calls."
    )
    provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
    model: Optional[str] = Field(None, description="The name of the model used for this step.")
    context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.")
    completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.")
    prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
    total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
    completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
    # FIX: use default_factory instead of sharing a literal [] default.
    tags: List[str] = Field(default_factory=list, description="Metadata tags.")
    tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
    messages: List[Message] = Field(default_factory=list, description="The messages generated during this step.")

View File

@@ -595,9 +595,6 @@ async def process_message_background(
)
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
# Add job usage statistics
server.job_manager.add_job_usage(job_id=job_id, usage=result.usage, actor=actor)
except Exception as e:
# Update job status to failed
job_update = JobUpdate(

View File

@@ -72,6 +72,7 @@ from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.provider_manager import ProviderManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.step_manager import StepManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
@@ -293,6 +294,7 @@ class SyncServer(Server):
self.job_manager = JobManager()
self.agent_manager = AgentManager()
self.provider_manager = ProviderManager()
self.step_manager = StepManager()
# Managers that interface with parallelism
self.per_agent_lock_manager = PerAgentLockManager()

View File

@@ -1,3 +1,5 @@
from functools import reduce
from operator import add
from typing import List, Literal, Optional, Union
from sqlalchemy import select
@@ -7,9 +9,9 @@ from letta.orm.enums import JobType
from letta.orm.errors import NoResultFound
from letta.orm.job import Job as JobModel
from letta.orm.job_messages import JobMessage
from letta.orm.job_usage_statistics import JobUsageStatistics
from letta.orm.message import Message as MessageModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate
@@ -193,12 +195,7 @@ class JobManager:
self._verify_job_access(session, job_id, actor)
# Get the latest usage statistics for the job
latest_stats = (
session.query(JobUsageStatistics)
.filter(JobUsageStatistics.job_id == job_id)
.order_by(JobUsageStatistics.created_at.desc())
.first()
)
latest_stats = session.query(Step).filter(Step.job_id == job_id).order_by(Step.created_at.desc()).all()
if not latest_stats:
return LettaUsageStatistics(
@@ -209,10 +206,10 @@ class JobManager:
)
return LettaUsageStatistics(
completion_tokens=latest_stats.completion_tokens,
prompt_tokens=latest_stats.prompt_tokens,
total_tokens=latest_stats.total_tokens,
step_count=latest_stats.step_count,
completion_tokens=reduce(add, (step.completion_tokens or 0 for step in latest_stats), 0),
prompt_tokens=reduce(add, (step.prompt_tokens or 0 for step in latest_stats), 0),
total_tokens=reduce(add, (step.total_tokens or 0 for step in latest_stats), 0),
step_count=len(latest_stats),
)
@enforce_types
@@ -239,8 +236,9 @@ class JobManager:
# First verify job exists and user has access
self._verify_job_access(session, job_id, actor, access=["write"])
# Create new usage statistics entry
usage_stats = JobUsageStatistics(
# Manually log step with usage data
# TODO(@caren): log step under the hood and remove this
usage_stats = Step(
job_id=job_id,
completion_tokens=usage.completion_tokens,
prompt_tokens=usage.prompt_tokens,

View File

@@ -48,9 +48,13 @@ class ProviderManager:
def delete_provider_by_id(self, provider_id: str):
"""Delete a provider."""
with self.session_maker() as session:
# Delete from provider table
provider = ProviderModel.read(db_session=session, identifier=provider_id)
provider.hard_delete(session)
# Clear api key field
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id)
existing_provider.api_key = None
existing_provider.update(session)
# Soft delete in provider table
existing_provider.delete(session)
session.commit()
@@ -62,9 +66,17 @@ class ProviderManager:
return [provider.to_pydantic() for provider in results]
@enforce_types
def get_anthropic_key_override(self) -> Optional[str]:
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
providers = self.list_providers(limit=1)
if len(providers) == 1 and providers[0].name == "anthropic":
return providers[0].api_key
def get_anthropic_override_provider_id(self) -> Optional[str]:
    """Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
    # Return the id of the first provider named "anthropic", if any.
    for existing_provider in self.list_providers():
        if existing_provider.name == "anthropic":
            return existing_provider.id
    return None
@enforce_types
def get_anthropic_override_key(self) -> Optional[str]:
    """Helper function to fetch custom anthropic key for v0 BYOK feature"""
    # Return the api_key of the first provider named "anthropic", if any.
    for existing_provider in self.list_providers():
        if existing_provider.name == "anthropic":
            return existing_provider.api_key
    return None

View File

@@ -0,0 +1,87 @@
from typing import List, Literal, Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
from letta.orm.errors import NoResultFound
from letta.orm.job import Job as JobModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class StepManager:
    """Manager responsible for persisting and retrieving Step records."""

    def __init__(self):
        # Deferred import to avoid a circular import with the server module.
        from letta.server.server import db_context

        self.session_maker = db_context

    @enforce_types
    def log_step(
        self,
        actor: PydanticUser,
        provider_name: str,
        model: str,
        context_window_limit: int,
        usage: UsageStatistics,
        provider_id: Optional[str] = None,
        job_id: Optional[str] = None,
    ) -> PydanticStep:
        """Persist a new step row recording provider/model metadata and token usage.

        Args:
            actor: The user on whose behalf the step ran; supplies organization_id.
            provider_name: Name of the LLM provider used for this step.
            model: Name of the model used for this step.
            context_window_limit: Context window configured for this step.
            usage: Token usage statistics for the step.
            provider_id: Optional id of a configured provider (v0 BYOK override).
            job_id: Optional id of the job run that triggered this step.

        Returns:
            The created step as a Pydantic model.

        Raises:
            NoResultFound: If job_id is given but the job does not exist or the
                actor lacks write access to it.
        """
        step_data = {
            "origin": None,
            "organization_id": actor.organization_id,
            "provider_id": provider_id,
            "provider_name": provider_name,
            "model": model,
            "context_window_limit": context_window_limit,
            "completion_tokens": usage.completion_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "total_tokens": usage.total_tokens,
            "job_id": job_id,
            "tags": [],
            "tid": None,
        }
        with self.session_maker() as session:
            if job_id:
                self._verify_job_access(session, job_id, actor, access=["write"])
            new_step = StepModel(**step_data)
            new_step.create(session)
            return new_step.to_pydantic()

    @enforce_types
    def get_step(self, step_id: str) -> PydanticStep:
        """Fetch a single step by id.

        Raises:
            NoResultFound: If no step with the given id exists.
        """
        with self.session_maker() as session:
            step = StepModel.read(db_session=session, identifier=step_id)
            return step.to_pydantic()

    def _verify_job_access(
        self,
        session: Session,
        job_id: str,
        actor: PydanticUser,
        access: Optional[List[Literal["read", "write", "delete"]]] = None,
    ) -> JobModel:
        """
        Verify that a job exists and the user has the required access.

        Args:
            session: The database session
            job_id: The ID of the job to verify
            actor: The user making the request
            access: Access levels required; defaults to ["read"].

        Returns:
            The job if it exists and the user has access

        Raises:
            NoResultFound: If the job does not exist or user does not have access
        """
        # FIX: avoid a mutable default argument; fall back to read access.
        if access is None:
            access = ["read"]
        job_query = select(JobModel).where(JobModel.id == job_id)
        job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER)
        job = session.execute(job_query).scalar_one_or_none()
        if not job:
            raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
        return job

View File

@@ -26,6 +26,7 @@ from letta.orm import (
Source,
SourcePassage,
SourcesAgents,
Step,
Tool,
ToolsAgents,
User,
@@ -46,6 +47,7 @@ from letta.schemas.letta_request import LettaRequestConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.passage import Passage as PydanticPassage
@@ -56,7 +58,6 @@ from letta.schemas.source import SourceUpdate
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolUpdate
from letta.schemas.tool_rule import InitToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.server import SyncServer
@@ -100,6 +101,7 @@ def clear_tables(server: SyncServer):
session.execute(delete(Tool)) # Clear all records from the Tool table
session.execute(delete(Agent))
session.execute(delete(User)) # Clear all records from the user table
session.execute(delete(Step))
session.execute(delete(Provider))
session.execute(delete(Organization)) # Clear all records from the organization table
session.commit() # Commit the deletion
@@ -2835,17 +2837,19 @@ def test_get_run_messages_cursor(server: SyncServer, default_user: PydanticUser,
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
"""Test adding and retrieving job usage statistics."""
job_manager = server.job_manager
step_manager = server.step_manager
# Add usage statistics
job_manager.add_job_usage(
step_manager.log_step(
provider_name="openai",
model="gpt-4",
context_window_limit=8192,
job_id=default_job.id,
usage=LettaUsageStatistics(
usage=UsageStatistics(
completion_tokens=100,
prompt_tokens=50,
total_tokens=150,
step_count=5,
),
step_id="step_1",
actor=default_user,
)
@@ -2874,30 +2878,33 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
"""Test adding multiple usage statistics entries for a job."""
job_manager = server.job_manager
step_manager = server.step_manager
# Add first usage statistics entry
job_manager.add_job_usage(
step_manager.log_step(
provider_name="openai",
model="gpt-4",
context_window_limit=8192,
job_id=default_job.id,
usage=LettaUsageStatistics(
usage=UsageStatistics(
completion_tokens=100,
prompt_tokens=50,
total_tokens=150,
step_count=5,
),
step_id="step_1",
actor=default_user,
)
# Add second usage statistics entry
job_manager.add_job_usage(
step_manager.log_step(
provider_name="openai",
model="gpt-4",
context_window_limit=8192,
job_id=default_job.id,
usage=LettaUsageStatistics(
usage=UsageStatistics(
completion_tokens=200,
prompt_tokens=100,
total_tokens=300,
step_count=10,
),
step_id="step_2",
actor=default_user,
)
@@ -2905,9 +2912,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
usage_stats = job_manager.get_job_usage(job_id=default_job.id, actor=default_user)
# Verify we get the most recent statistics
assert usage_stats.completion_tokens == 200
assert usage_stats.prompt_tokens == 100
assert usage_stats.total_tokens == 300
assert usage_stats.completion_tokens == 300
assert usage_stats.prompt_tokens == 150
assert usage_stats.total_tokens == 450
assert usage_stats.step_count == 2
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
@@ -2920,18 +2928,19 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
"""Test adding usage statistics for a nonexistent job."""
job_manager = server.job_manager
step_manager = server.step_manager
with pytest.raises(NoResultFound):
job_manager.add_job_usage(
step_manager.log_step(
provider_name="openai",
model="gpt-4",
context_window_limit=8192,
job_id="nonexistent_job",
usage=LettaUsageStatistics(
usage=UsageStatistics(
completion_tokens=100,
prompt_tokens=50,
total_tokens=150,
step_count=5,
),
step_id="step_1",
actor=default_user,
)

View File

@@ -1,15 +1,19 @@
import json
import os
import uuid
import warnings
from typing import List, Tuple
import pytest
from sqlalchemy import delete
import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.orm import Provider, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.user import User
utils.DEBUG = True
@@ -277,6 +281,10 @@ def org_id(server):
yield org.id
# cleanup
with server.organization_manager.session_maker() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()
server.organization_manager.delete_organization_by_id(org.id)
@@ -1098,3 +1106,72 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
request.tool_ids = [b.id for b in base_tools[:-2]]
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
assert len(agent_state.tools) == len(base_tools) - 2
def test_messages_with_provider_override(server: SyncServer, user_id: str):
    """Verify steps are logged with the overriding provider's id (v0 BYOK) and
    fall back to provider_id=None once the override provider is deleted.

    NOTE(review): requires ANTHROPIC_API_KEY in the environment and issues real
    model-backed calls via server.user_message — presumably an integration test.
    """
    actor = server.user_manager.get_user_or_default(user_id)
    # Register a BYOK anthropic provider; new steps should reference its id.
    provider = server.provider_manager.create_provider(
        provider=PydanticProvider(
            name="anthropic",
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        ),
        actor=actor,
    )
    agent = server.create_agent(
        request=CreateAgent(
            memory_blocks=[], llm="anthropic/claude-3-opus-20240229", context_window_limit=200000, embedding="openai/text-embedding-ada-002"
        ),
        actor=actor,
    )
    existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
    usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
    assert usage, "Sending message failed"
    # Fetch only the messages created after the cursor (the pre-existing tail).
    get_messages_response = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id)
    assert len(get_messages_response) > 0, "Retrieving messages failed"
    # Every new message should carry the step it was created in.
    step_ids = set([msg.step_id for msg in get_messages_response])
    completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
    for step_id in step_ids:
        step = server.step_manager.get_step(step_id=step_id)
        assert step, "Step was not logged correctly"
        assert step.provider_id == provider.id
        assert step.provider_name == agent.llm_config.model_endpoint_type
        assert step.model == agent.llm_config.model
        assert step.context_window_limit == agent.llm_config.context_window
        completion_tokens += int(step.completion_tokens)
        prompt_tokens += int(step.prompt_tokens)
        total_tokens += int(step.total_tokens)
    # Token usage summed across steps must match the reported usage totals.
    assert completion_tokens == usage.completion_tokens
    assert prompt_tokens == usage.prompt_tokens
    assert total_tokens == usage.total_tokens
    # Delete the override provider; subsequent steps should have no provider_id.
    server.provider_manager.delete_provider_by_id(provider.id)
    existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
    usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
    assert usage, "Sending message failed"
    get_messages_response = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id)
    assert len(get_messages_response) > 0, "Retrieving messages failed"
    step_ids = set([msg.step_id for msg in get_messages_response])
    completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
    for step_id in step_ids:
        step = server.step_manager.get_step(step_id=step_id)
        assert step, "Step was not logged correctly"
        assert step.provider_id == None
        assert step.provider_name == agent.llm_config.model_endpoint_type
        assert step.model == agent.llm_config.model
        assert step.context_window_limit == agent.llm_config.context_window
        completion_tokens += int(step.completion_tokens)
        prompt_tokens += int(step.prompt_tokens)
        total_tokens += int(step.total_tokens)
    assert completion_tokens == usage.completion_tokens
    assert prompt_tokens == usage.prompt_tokens
    assert total_tokens == usage.total_tokens