feat: add schema/db for new steps table (#669)
This commit is contained in:
@@ -0,0 +1,116 @@
|
|||||||
|
"""Repurpose JobUsageStatistics for new Steps table
|
||||||
|
|
||||||
|
Revision ID: 416b9d2db10b
|
||||||
|
Revises: 25fc99e97839
|
||||||
|
Create Date: 2025-01-17 11:27:42.115755
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "416b9d2db10b"
|
||||||
|
down_revision: Union[str, None] = "25fc99e97839"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
# Rename the table
|
||||||
|
op.rename_table("job_usage_statistics", "steps")
|
||||||
|
|
||||||
|
# Rename the foreign key constraint and drop non-null constraint
|
||||||
|
op.alter_column("steps", "job_id", nullable=True)
|
||||||
|
op.drop_constraint("fk_job_usage_statistics_job_id", "steps", type_="foreignkey")
|
||||||
|
|
||||||
|
# Change id field from int to string
|
||||||
|
op.execute("ALTER TABLE steps RENAME COLUMN id TO old_id")
|
||||||
|
op.add_column("steps", sa.Column("id", sa.String(), nullable=True))
|
||||||
|
op.execute("""UPDATE steps SET id = 'step-' || gen_random_uuid()::text""")
|
||||||
|
op.drop_column("steps", "old_id")
|
||||||
|
op.alter_column("steps", "id", nullable=False)
|
||||||
|
op.create_primary_key("pk_steps_id", "steps", ["id"])
|
||||||
|
|
||||||
|
# Add new columns
|
||||||
|
op.add_column("steps", sa.Column("origin", sa.String(), nullable=True))
|
||||||
|
op.add_column("steps", sa.Column("organization_id", sa.String(), nullable=True))
|
||||||
|
op.add_column("steps", sa.Column("provider_id", sa.String(), nullable=True))
|
||||||
|
op.add_column("steps", sa.Column("provider_name", sa.String(), nullable=True))
|
||||||
|
op.add_column("steps", sa.Column("model", sa.String(), nullable=True))
|
||||||
|
op.add_column("steps", sa.Column("context_window_limit", sa.Integer(), nullable=True))
|
||||||
|
op.add_column(
|
||||||
|
"steps",
|
||||||
|
sa.Column("completion_tokens_details", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"steps",
|
||||||
|
sa.Column("tags", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column("steps", sa.Column("tid", sa.String(), nullable=True))
|
||||||
|
|
||||||
|
# Add new foreign key constraint for provider_id
|
||||||
|
op.create_foreign_key("fk_steps_organization_id", "steps", "providers", ["provider_id"], ["id"], ondelete="RESTRICT")
|
||||||
|
|
||||||
|
# Add new foreign key constraint for provider_id
|
||||||
|
op.create_foreign_key("fk_steps_provider_id", "steps", "organizations", ["organization_id"], ["id"], ondelete="RESTRICT")
|
||||||
|
|
||||||
|
# Add new foreign key constraint for provider_id
|
||||||
|
op.create_foreign_key("fk_steps_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="SET NULL")
|
||||||
|
|
||||||
|
# Drop old step_id and step_count columns which aren't in the new model
|
||||||
|
op.drop_column("steps", "step_id")
|
||||||
|
op.drop_column("steps", "step_count")
|
||||||
|
|
||||||
|
# Add step_id to messages table
|
||||||
|
op.add_column("messages", sa.Column("step_id", sa.String(), nullable=True))
|
||||||
|
op.create_foreign_key("fk_messages_step_id", "messages", "steps", ["step_id"], ["id"], ondelete="SET NULL")
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
# Remove step_id from messages first to avoid foreign key conflicts
|
||||||
|
op.drop_constraint("fk_messages_step_id", "messages", type_="foreignkey")
|
||||||
|
op.drop_column("messages", "step_id")
|
||||||
|
|
||||||
|
# Restore old step_count and step_id column
|
||||||
|
op.add_column("steps", sa.Column("step_count", sa.Integer(), nullable=True))
|
||||||
|
op.add_column("steps", sa.Column("step_id", sa.String(), nullable=True))
|
||||||
|
|
||||||
|
# Drop new columns and constraints
|
||||||
|
op.drop_constraint("fk_steps_provider_id", "steps", type_="foreignkey")
|
||||||
|
op.drop_constraint("fk_steps_organization_id", "steps", type_="foreignkey")
|
||||||
|
op.drop_constraint("fk_steps_job_id", "steps", type_="foreignkey")
|
||||||
|
|
||||||
|
op.drop_column("steps", "tid")
|
||||||
|
op.drop_column("steps", "tags")
|
||||||
|
op.drop_column("steps", "completion_tokens_details")
|
||||||
|
op.drop_column("steps", "context_window_limit")
|
||||||
|
op.drop_column("steps", "model")
|
||||||
|
op.drop_column("steps", "provider_name")
|
||||||
|
op.drop_column("steps", "provider_id")
|
||||||
|
op.drop_column("steps", "organization_id")
|
||||||
|
op.drop_column("steps", "origin")
|
||||||
|
|
||||||
|
# Add constraints back
|
||||||
|
op.execute("DELETE FROM steps WHERE job_id IS NULL")
|
||||||
|
op.alter_column("steps", "job_id", nullable=False)
|
||||||
|
op.create_foreign_key("fk_job_usage_statistics_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="CASCADE")
|
||||||
|
|
||||||
|
# Change id field from string back to int
|
||||||
|
op.add_column("steps", sa.Column("old_id", sa.Integer(), nullable=True))
|
||||||
|
op.execute("""UPDATE steps SET old_id = CAST(ABS(hashtext(REPLACE(id, 'step-', '')::text)) AS integer)""")
|
||||||
|
op.drop_column("steps", "id")
|
||||||
|
op.execute("ALTER TABLE steps RENAME COLUMN old_id TO id")
|
||||||
|
op.alter_column("steps", "id", nullable=False)
|
||||||
|
op.create_primary_key("pk_steps_id", "steps", ["id"])
|
||||||
|
|
||||||
|
# Rename the table
|
||||||
|
op.rename_table("steps", "job_usage_statistics")
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -49,6 +49,8 @@ from letta.services.helpers.agent_manager_helper import check_supports_structure
|
|||||||
from letta.services.job_manager import JobManager
|
from letta.services.job_manager import JobManager
|
||||||
from letta.services.message_manager import MessageManager
|
from letta.services.message_manager import MessageManager
|
||||||
from letta.services.passage_manager import PassageManager
|
from letta.services.passage_manager import PassageManager
|
||||||
|
from letta.services.provider_manager import ProviderManager
|
||||||
|
from letta.services.step_manager import StepManager
|
||||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||||
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
|
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
|
||||||
@@ -130,8 +132,10 @@ class Agent(BaseAgent):
|
|||||||
# Create the persistence manager object based on the AgentState info
|
# Create the persistence manager object based on the AgentState info
|
||||||
self.message_manager = MessageManager()
|
self.message_manager = MessageManager()
|
||||||
self.passage_manager = PassageManager()
|
self.passage_manager = PassageManager()
|
||||||
|
self.provider_manager = ProviderManager()
|
||||||
self.agent_manager = AgentManager()
|
self.agent_manager = AgentManager()
|
||||||
self.job_manager = JobManager()
|
self.job_manager = JobManager()
|
||||||
|
self.step_manager = StepManager()
|
||||||
|
|
||||||
# State needed for heartbeat pausing
|
# State needed for heartbeat pausing
|
||||||
|
|
||||||
@@ -764,6 +768,24 @@ class Agent(BaseAgent):
|
|||||||
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
|
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Log step - this must happen before messages are persisted
|
||||||
|
step = self.step_manager.log_step(
|
||||||
|
actor=self.user,
|
||||||
|
provider_name=self.agent_state.llm_config.model_endpoint_type,
|
||||||
|
model=self.agent_state.llm_config.model,
|
||||||
|
context_window_limit=self.agent_state.llm_config.context_window,
|
||||||
|
usage=response.usage,
|
||||||
|
# TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
|
||||||
|
provider_id=(
|
||||||
|
self.provider_manager.get_anthropic_override_provider_id()
|
||||||
|
if self.agent_state.llm_config.model_endpoint_type == "anthropic"
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
job_id=job_id,
|
||||||
|
)
|
||||||
|
for message in all_new_messages:
|
||||||
|
message.step_id = step.id
|
||||||
|
|
||||||
# Persisting into Messages
|
# Persisting into Messages
|
||||||
self.agent_state = self.agent_manager.append_to_in_context_messages(
|
self.agent_state = self.agent_manager.append_to_in_context_messages(
|
||||||
all_new_messages, agent_id=self.agent_state.id, actor=self.user
|
all_new_messages, agent_id=self.agent_state.id, actor=self.user
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from letta.schemas.openai.chat_completion_response import (
|
|||||||
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
||||||
)
|
)
|
||||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||||
|
from letta.services.provider_manager import ProviderManager
|
||||||
from letta.settings import model_settings
|
from letta.settings import model_settings
|
||||||
from letta.utils import get_utc_time, smart_urljoin
|
from letta.utils import get_utc_time, smart_urljoin
|
||||||
|
|
||||||
@@ -39,9 +40,6 @@ MODEL_LIST = [
|
|||||||
|
|
||||||
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
||||||
|
|
||||||
if model_settings.anthropic_api_key:
|
|
||||||
anthropic_client = anthropic.Anthropic()
|
|
||||||
|
|
||||||
|
|
||||||
def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
|
def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
|
||||||
for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
|
for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
|
||||||
@@ -397,6 +395,12 @@ def anthropic_chat_completions_request(
|
|||||||
betas: List[str] = ["tools-2024-04-04"],
|
betas: List[str] = ["tools-2024-04-04"],
|
||||||
) -> ChatCompletionResponse:
|
) -> ChatCompletionResponse:
|
||||||
"""https://docs.anthropic.com/claude/docs/tool-use"""
|
"""https://docs.anthropic.com/claude/docs/tool-use"""
|
||||||
|
anthropic_client = None
|
||||||
|
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||||
|
if anthropic_override_key:
|
||||||
|
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||||
|
elif model_settings.anthropic_api_key:
|
||||||
|
anthropic_client = anthropic.Anthropic()
|
||||||
data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
|
data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
|
||||||
response = anthropic_client.beta.messages.create(
|
response = anthropic_client.beta.messages.create(
|
||||||
**data,
|
**data,
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from letta.orm.blocks_agents import BlocksAgents
|
|||||||
from letta.orm.file import FileMetadata
|
from letta.orm.file import FileMetadata
|
||||||
from letta.orm.job import Job
|
from letta.orm.job import Job
|
||||||
from letta.orm.job_messages import JobMessage
|
from letta.orm.job_messages import JobMessage
|
||||||
from letta.orm.job_usage_statistics import JobUsageStatistics
|
|
||||||
from letta.orm.message import Message
|
from letta.orm.message import Message
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
||||||
@@ -14,6 +13,7 @@ from letta.orm.provider import Provider
|
|||||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
||||||
from letta.orm.source import Source
|
from letta.orm.source import Source
|
||||||
from letta.orm.sources_agents import SourcesAgents
|
from letta.orm.sources_agents import SourcesAgents
|
||||||
|
from letta.orm.step import Step
|
||||||
from letta.orm.tool import Tool
|
from letta.orm.tool import Tool
|
||||||
from letta.orm.tools_agents import ToolsAgents
|
from letta.orm.tools_agents import ToolsAgents
|
||||||
from letta.orm.user import User
|
from letta.orm.user import User
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from letta.schemas.letta_request import LettaRequestConfig
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.job_messages import JobMessage
|
from letta.orm.job_messages import JobMessage
|
||||||
from letta.orm.job_usage_statistics import JobUsageStatistics
|
|
||||||
from letta.orm.message import Message
|
from letta.orm.message import Message
|
||||||
|
from letta.orm.step import Step
|
||||||
from letta.orm.user import User
|
from letta.orm.user import User
|
||||||
|
|
||||||
|
|
||||||
@@ -41,9 +41,7 @@ class Job(SqlalchemyBase, UserMixin):
|
|||||||
# relationships
|
# relationships
|
||||||
user: Mapped["User"] = relationship("User", back_populates="jobs")
|
user: Mapped["User"] = relationship("User", back_populates="jobs")
|
||||||
job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")
|
job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")
|
||||||
usage_statistics: Mapped[list["JobUsageStatistics"]] = relationship(
|
steps: Mapped[List["Step"]] = relationship("Step", back_populates="job", cascade="save-update")
|
||||||
"JobUsageStatistics", back_populates="job", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def messages(self) -> List["Message"]:
|
def messages(self) -> List["Message"]:
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
from typing import TYPE_CHECKING, Optional
|
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
||||||
|
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from letta.orm.job import Job
|
|
||||||
|
|
||||||
|
|
||||||
class JobUsageStatistics(SqlalchemyBase):
|
|
||||||
"""Tracks usage statistics for jobs, with future support for per-step tracking."""
|
|
||||||
|
|
||||||
__tablename__ = "job_usage_statistics"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True, doc="Unique identifier for the usage statistics entry")
|
|
||||||
job_id: Mapped[str] = mapped_column(
|
|
||||||
ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the job these statistics belong to"
|
|
||||||
)
|
|
||||||
step_id: Mapped[Optional[str]] = mapped_column(
|
|
||||||
nullable=True, doc="ID of the specific step within the job (for future per-step tracking)"
|
|
||||||
)
|
|
||||||
completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
|
|
||||||
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
|
|
||||||
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
|
|
||||||
step_count: Mapped[int] = mapped_column(default=0, doc="Number of steps taken by the agent")
|
|
||||||
|
|
||||||
# Relationship back to the job
|
|
||||||
job: Mapped["Job"] = relationship("Job", back_populates="usage_statistics")
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import Index
|
from sqlalchemy import ForeignKey, Index
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from letta.orm.custom_columns import ToolCallColumn
|
from letta.orm.custom_columns import ToolCallColumn
|
||||||
@@ -24,10 +24,14 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|||||||
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
|
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
|
||||||
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
|
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
|
||||||
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
|
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
|
||||||
|
step_id: Mapped[Optional[str]] = mapped_column(
|
||||||
|
ForeignKey("steps.id", ondelete="SET NULL"), nullable=True, doc="ID of the step that this message belongs to"
|
||||||
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
||||||
|
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin")
|
||||||
|
|
||||||
# Job relationship
|
# Job relationship
|
||||||
job_message: Mapped[Optional["JobMessage"]] = relationship(
|
job_message: Mapped[Optional["JobMessage"]] = relationship(
|
||||||
|
|||||||
54
letta/orm/step.py
Normal file
54
letta/orm/step.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import JSON, ForeignKey, String
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
|
from letta.schemas.step import Step as PydanticStep
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from letta.orm.job import Job
|
||||||
|
from letta.orm.provider import Provider
|
||||||
|
|
||||||
|
|
||||||
|
class Step(SqlalchemyBase):
|
||||||
|
"""Tracks all metadata for agent step."""
|
||||||
|
|
||||||
|
__tablename__ = "steps"
|
||||||
|
__pydantic_model__ = PydanticStep
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"step-{uuid.uuid4()}")
|
||||||
|
origin: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The surface that this agent step was initiated from.")
|
||||||
|
organization_id: Mapped[str] = mapped_column(
|
||||||
|
ForeignKey("organizations.id", ondelete="RESTRICT"),
|
||||||
|
nullable=True,
|
||||||
|
doc="The unique identifier of the organization that this step ran for",
|
||||||
|
)
|
||||||
|
provider_id: Mapped[Optional[str]] = mapped_column(
|
||||||
|
ForeignKey("providers.id", ondelete="RESTRICT"),
|
||||||
|
nullable=True,
|
||||||
|
doc="The unique identifier of the provider that was configured for this step",
|
||||||
|
)
|
||||||
|
job_id: Mapped[Optional[str]] = mapped_column(
|
||||||
|
ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identified of the job run that triggered this step"
|
||||||
|
)
|
||||||
|
provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.")
|
||||||
|
model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
|
||||||
|
context_window_limit: Mapped[Optional[int]] = mapped_column(
|
||||||
|
None, nullable=True, doc="The context window limit configured for this step."
|
||||||
|
)
|
||||||
|
completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
|
||||||
|
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
|
||||||
|
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
|
||||||
|
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
|
||||||
|
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
|
||||||
|
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
|
||||||
|
|
||||||
|
# Relationships (foreign keys)
|
||||||
|
organization: Mapped[Optional["Organization"]] = relationship("Organization")
|
||||||
|
provider: Mapped[Optional["Provider"]] = relationship("Provider")
|
||||||
|
job: Mapped[Optional["Job"]] = relationship("Job", back_populates="steps")
|
||||||
|
|
||||||
|
# Relationships (backrefs)
|
||||||
|
messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload")
|
||||||
@@ -99,6 +99,7 @@ class Message(BaseMessage):
|
|||||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||||
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
|
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
|
||||||
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
||||||
|
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
|
||||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||||
|
|
||||||
|
|||||||
31
letta/schemas/step.py
Normal file
31
letta/schemas/step.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from letta.schemas.letta_base import LettaBase
|
||||||
|
from letta.schemas.message import Message
|
||||||
|
|
||||||
|
|
||||||
|
class StepBase(LettaBase):
|
||||||
|
__id_prefix__ = "step"
|
||||||
|
|
||||||
|
|
||||||
|
class Step(StepBase):
|
||||||
|
id: str = Field(..., description="The id of the step. Assigned by the database.")
|
||||||
|
origin: Optional[str] = Field(None, description="The surface that this agent step was initiated from.")
|
||||||
|
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the step.")
|
||||||
|
provider_id: Optional[str] = Field(None, description="The unique identifier of the provider that was configured for this step")
|
||||||
|
job_id: Optional[str] = Field(
|
||||||
|
None, description="The unique identifier of the job that this step belongs to. Only included for async calls."
|
||||||
|
)
|
||||||
|
provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
|
||||||
|
model: Optional[str] = Field(None, description="The name of the model used for this step.")
|
||||||
|
context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.")
|
||||||
|
completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.")
|
||||||
|
prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
|
||||||
|
total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
|
||||||
|
completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
|
||||||
|
|
||||||
|
tags: List[str] = Field([], description="Metadata tags.")
|
||||||
|
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
|
||||||
|
messages: List[Message] = Field([], description="The messages generated during this step.")
|
||||||
@@ -595,9 +595,6 @@ async def process_message_background(
|
|||||||
)
|
)
|
||||||
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
|
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
|
||||||
|
|
||||||
# Add job usage statistics
|
|
||||||
server.job_manager.add_job_usage(job_id=job_id, usage=result.usage, actor=actor)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Update job status to failed
|
# Update job status to failed
|
||||||
job_update = JobUpdate(
|
job_update = JobUpdate(
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ from letta.services.per_agent_lock_manager import PerAgentLockManager
|
|||||||
from letta.services.provider_manager import ProviderManager
|
from letta.services.provider_manager import ProviderManager
|
||||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||||
from letta.services.source_manager import SourceManager
|
from letta.services.source_manager import SourceManager
|
||||||
|
from letta.services.step_manager import StepManager
|
||||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||||
from letta.services.tool_manager import ToolManager
|
from letta.services.tool_manager import ToolManager
|
||||||
from letta.services.user_manager import UserManager
|
from letta.services.user_manager import UserManager
|
||||||
@@ -293,6 +294,7 @@ class SyncServer(Server):
|
|||||||
self.job_manager = JobManager()
|
self.job_manager = JobManager()
|
||||||
self.agent_manager = AgentManager()
|
self.agent_manager = AgentManager()
|
||||||
self.provider_manager = ProviderManager()
|
self.provider_manager = ProviderManager()
|
||||||
|
self.step_manager = StepManager()
|
||||||
|
|
||||||
# Managers that interface with parallelism
|
# Managers that interface with parallelism
|
||||||
self.per_agent_lock_manager = PerAgentLockManager()
|
self.per_agent_lock_manager = PerAgentLockManager()
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from functools import reduce
|
||||||
|
from operator import add
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -7,9 +9,9 @@ from letta.orm.enums import JobType
|
|||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.job import Job as JobModel
|
from letta.orm.job import Job as JobModel
|
||||||
from letta.orm.job_messages import JobMessage
|
from letta.orm.job_messages import JobMessage
|
||||||
from letta.orm.job_usage_statistics import JobUsageStatistics
|
|
||||||
from letta.orm.message import Message as MessageModel
|
from letta.orm.message import Message as MessageModel
|
||||||
from letta.orm.sqlalchemy_base import AccessType
|
from letta.orm.sqlalchemy_base import AccessType
|
||||||
|
from letta.orm.step import Step
|
||||||
from letta.schemas.enums import JobStatus, MessageRole
|
from letta.schemas.enums import JobStatus, MessageRole
|
||||||
from letta.schemas.job import Job as PydanticJob
|
from letta.schemas.job import Job as PydanticJob
|
||||||
from letta.schemas.job import JobUpdate
|
from letta.schemas.job import JobUpdate
|
||||||
@@ -193,12 +195,7 @@ class JobManager:
|
|||||||
self._verify_job_access(session, job_id, actor)
|
self._verify_job_access(session, job_id, actor)
|
||||||
|
|
||||||
# Get the latest usage statistics for the job
|
# Get the latest usage statistics for the job
|
||||||
latest_stats = (
|
latest_stats = session.query(Step).filter(Step.job_id == job_id).order_by(Step.created_at.desc()).all()
|
||||||
session.query(JobUsageStatistics)
|
|
||||||
.filter(JobUsageStatistics.job_id == job_id)
|
|
||||||
.order_by(JobUsageStatistics.created_at.desc())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not latest_stats:
|
if not latest_stats:
|
||||||
return LettaUsageStatistics(
|
return LettaUsageStatistics(
|
||||||
@@ -209,10 +206,10 @@ class JobManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return LettaUsageStatistics(
|
return LettaUsageStatistics(
|
||||||
completion_tokens=latest_stats.completion_tokens,
|
completion_tokens=reduce(add, (step.completion_tokens or 0 for step in latest_stats), 0),
|
||||||
prompt_tokens=latest_stats.prompt_tokens,
|
prompt_tokens=reduce(add, (step.prompt_tokens or 0 for step in latest_stats), 0),
|
||||||
total_tokens=latest_stats.total_tokens,
|
total_tokens=reduce(add, (step.total_tokens or 0 for step in latest_stats), 0),
|
||||||
step_count=latest_stats.step_count,
|
step_count=len(latest_stats),
|
||||||
)
|
)
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
@@ -239,8 +236,9 @@ class JobManager:
|
|||||||
# First verify job exists and user has access
|
# First verify job exists and user has access
|
||||||
self._verify_job_access(session, job_id, actor, access=["write"])
|
self._verify_job_access(session, job_id, actor, access=["write"])
|
||||||
|
|
||||||
# Create new usage statistics entry
|
# Manually log step with usage data
|
||||||
usage_stats = JobUsageStatistics(
|
# TODO(@caren): log step under the hood and remove this
|
||||||
|
usage_stats = Step(
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
completion_tokens=usage.completion_tokens,
|
completion_tokens=usage.completion_tokens,
|
||||||
prompt_tokens=usage.prompt_tokens,
|
prompt_tokens=usage.prompt_tokens,
|
||||||
|
|||||||
@@ -48,9 +48,13 @@ class ProviderManager:
|
|||||||
def delete_provider_by_id(self, provider_id: str):
|
def delete_provider_by_id(self, provider_id: str):
|
||||||
"""Delete a provider."""
|
"""Delete a provider."""
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
# Delete from provider table
|
# Clear api key field
|
||||||
provider = ProviderModel.read(db_session=session, identifier=provider_id)
|
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id)
|
||||||
provider.hard_delete(session)
|
existing_provider.api_key = None
|
||||||
|
existing_provider.update(session)
|
||||||
|
|
||||||
|
# Soft delete in provider table
|
||||||
|
existing_provider.delete(session)
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
@@ -62,9 +66,17 @@ class ProviderManager:
|
|||||||
return [provider.to_pydantic() for provider in results]
|
return [provider.to_pydantic() for provider in results]
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def get_anthropic_key_override(self) -> Optional[str]:
|
def get_anthropic_override_provider_id(self) -> Optional[str]:
|
||||||
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
|
"""Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
|
||||||
providers = self.list_providers(limit=1)
|
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
|
||||||
if len(providers) == 1 and providers[0].name == "anthropic":
|
if len(anthropic_provider) != 0:
|
||||||
return providers[0].api_key
|
return anthropic_provider[0].id
|
||||||
|
return None
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_anthropic_override_key(self) -> Optional[str]:
|
||||||
|
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
|
||||||
|
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
|
||||||
|
if len(anthropic_provider) != 0:
|
||||||
|
return anthropic_provider[0].api_key
|
||||||
return None
|
return None
|
||||||
|
|||||||
87
letta/services/step_manager.py
Normal file
87
letta/services/step_manager.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from letta.orm.errors import NoResultFound
|
||||||
|
from letta.orm.job import Job as JobModel
|
||||||
|
from letta.orm.sqlalchemy_base import AccessType
|
||||||
|
from letta.orm.step import Step as StepModel
|
||||||
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||||
|
from letta.schemas.step import Step as PydanticStep
|
||||||
|
from letta.schemas.user import User as PydanticUser
|
||||||
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|
||||||
|
class StepManager:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
from letta.server.server import db_context
|
||||||
|
|
||||||
|
self.session_maker = db_context
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def log_step(
|
||||||
|
self,
|
||||||
|
actor: PydanticUser,
|
||||||
|
provider_name: str,
|
||||||
|
model: str,
|
||||||
|
context_window_limit: int,
|
||||||
|
usage: UsageStatistics,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
job_id: Optional[str] = None,
|
||||||
|
) -> PydanticStep:
|
||||||
|
step_data = {
|
||||||
|
"origin": None,
|
||||||
|
"organization_id": actor.organization_id,
|
||||||
|
"provider_id": provider_id,
|
||||||
|
"provider_name": provider_name,
|
||||||
|
"model": model,
|
||||||
|
"context_window_limit": context_window_limit,
|
||||||
|
"completion_tokens": usage.completion_tokens,
|
||||||
|
"prompt_tokens": usage.prompt_tokens,
|
||||||
|
"total_tokens": usage.total_tokens,
|
||||||
|
"job_id": job_id,
|
||||||
|
"tags": [],
|
||||||
|
"tid": None,
|
||||||
|
}
|
||||||
|
with self.session_maker() as session:
|
||||||
|
if job_id:
|
||||||
|
self._verify_job_access(session, job_id, actor, access=["write"])
|
||||||
|
new_step = StepModel(**step_data)
|
||||||
|
new_step.create(session)
|
||||||
|
return new_step.to_pydantic()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_step(self, step_id: str) -> PydanticStep:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
step = StepModel.read(db_session=session, identifier=step_id)
|
||||||
|
return step.to_pydantic()
|
||||||
|
|
||||||
|
def _verify_job_access(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
job_id: str,
|
||||||
|
actor: PydanticUser,
|
||||||
|
access: List[Literal["read", "write", "delete"]] = ["read"],
|
||||||
|
) -> JobModel:
|
||||||
|
"""
|
||||||
|
Verify that a job exists and the user has the required access.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: The database session
|
||||||
|
job_id: The ID of the job to verify
|
||||||
|
actor: The user making the request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The job if it exists and the user has access
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NoResultFound: If the job does not exist or user does not have access
|
||||||
|
"""
|
||||||
|
job_query = select(JobModel).where(JobModel.id == job_id)
|
||||||
|
job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER)
|
||||||
|
job = session.execute(job_query).scalar_one_or_none()
|
||||||
|
if not job:
|
||||||
|
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
|
||||||
|
return job
|
||||||
@@ -26,6 +26,7 @@ from letta.orm import (
|
|||||||
Source,
|
Source,
|
||||||
SourcePassage,
|
SourcePassage,
|
||||||
SourcesAgents,
|
SourcesAgents,
|
||||||
|
Step,
|
||||||
Tool,
|
Tool,
|
||||||
ToolsAgents,
|
ToolsAgents,
|
||||||
User,
|
User,
|
||||||
@@ -46,6 +47,7 @@ from letta.schemas.letta_request import LettaRequestConfig
|
|||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage
|
||||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||||
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||||
from letta.schemas.organization import Organization as PydanticOrganization
|
from letta.schemas.organization import Organization as PydanticOrganization
|
||||||
from letta.schemas.passage import Passage as PydanticPassage
|
from letta.schemas.passage import Passage as PydanticPassage
|
||||||
@@ -56,7 +58,6 @@ from letta.schemas.source import SourceUpdate
|
|||||||
from letta.schemas.tool import Tool as PydanticTool
|
from letta.schemas.tool import Tool as PydanticTool
|
||||||
from letta.schemas.tool import ToolUpdate
|
from letta.schemas.tool import ToolUpdate
|
||||||
from letta.schemas.tool_rule import InitToolRule
|
from letta.schemas.tool_rule import InitToolRule
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.schemas.user import UserUpdate
|
from letta.schemas.user import UserUpdate
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
@@ -100,6 +101,7 @@ def clear_tables(server: SyncServer):
|
|||||||
session.execute(delete(Tool)) # Clear all records from the Tool table
|
session.execute(delete(Tool)) # Clear all records from the Tool table
|
||||||
session.execute(delete(Agent))
|
session.execute(delete(Agent))
|
||||||
session.execute(delete(User)) # Clear all records from the user table
|
session.execute(delete(User)) # Clear all records from the user table
|
||||||
|
session.execute(delete(Step))
|
||||||
session.execute(delete(Provider))
|
session.execute(delete(Provider))
|
||||||
session.execute(delete(Organization)) # Clear all records from the organization table
|
session.execute(delete(Organization)) # Clear all records from the organization table
|
||||||
session.commit() # Commit the deletion
|
session.commit() # Commit the deletion
|
||||||
@@ -2835,17 +2837,19 @@ def test_get_run_messages_cursor(server: SyncServer, default_user: PydanticUser,
|
|||||||
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
|
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
|
||||||
"""Test adding and retrieving job usage statistics."""
|
"""Test adding and retrieving job usage statistics."""
|
||||||
job_manager = server.job_manager
|
job_manager = server.job_manager
|
||||||
|
step_manager = server.step_manager
|
||||||
|
|
||||||
# Add usage statistics
|
# Add usage statistics
|
||||||
job_manager.add_job_usage(
|
step_manager.log_step(
|
||||||
|
provider_name="openai",
|
||||||
|
model="gpt-4",
|
||||||
|
context_window_limit=8192,
|
||||||
job_id=default_job.id,
|
job_id=default_job.id,
|
||||||
usage=LettaUsageStatistics(
|
usage=UsageStatistics(
|
||||||
completion_tokens=100,
|
completion_tokens=100,
|
||||||
prompt_tokens=50,
|
prompt_tokens=50,
|
||||||
total_tokens=150,
|
total_tokens=150,
|
||||||
step_count=5,
|
|
||||||
),
|
),
|
||||||
step_id="step_1",
|
|
||||||
actor=default_user,
|
actor=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2874,30 +2878,33 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
|
|||||||
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
|
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
|
||||||
"""Test adding multiple usage statistics entries for a job."""
|
"""Test adding multiple usage statistics entries for a job."""
|
||||||
job_manager = server.job_manager
|
job_manager = server.job_manager
|
||||||
|
step_manager = server.step_manager
|
||||||
|
|
||||||
# Add first usage statistics entry
|
# Add first usage statistics entry
|
||||||
job_manager.add_job_usage(
|
step_manager.log_step(
|
||||||
|
provider_name="openai",
|
||||||
|
model="gpt-4",
|
||||||
|
context_window_limit=8192,
|
||||||
job_id=default_job.id,
|
job_id=default_job.id,
|
||||||
usage=LettaUsageStatistics(
|
usage=UsageStatistics(
|
||||||
completion_tokens=100,
|
completion_tokens=100,
|
||||||
prompt_tokens=50,
|
prompt_tokens=50,
|
||||||
total_tokens=150,
|
total_tokens=150,
|
||||||
step_count=5,
|
|
||||||
),
|
),
|
||||||
step_id="step_1",
|
|
||||||
actor=default_user,
|
actor=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add second usage statistics entry
|
# Add second usage statistics entry
|
||||||
job_manager.add_job_usage(
|
step_manager.log_step(
|
||||||
|
provider_name="openai",
|
||||||
|
model="gpt-4",
|
||||||
|
context_window_limit=8192,
|
||||||
job_id=default_job.id,
|
job_id=default_job.id,
|
||||||
usage=LettaUsageStatistics(
|
usage=UsageStatistics(
|
||||||
completion_tokens=200,
|
completion_tokens=200,
|
||||||
prompt_tokens=100,
|
prompt_tokens=100,
|
||||||
total_tokens=300,
|
total_tokens=300,
|
||||||
step_count=10,
|
|
||||||
),
|
),
|
||||||
step_id="step_2",
|
|
||||||
actor=default_user,
|
actor=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2905,9 +2912,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
|
|||||||
usage_stats = job_manager.get_job_usage(job_id=default_job.id, actor=default_user)
|
usage_stats = job_manager.get_job_usage(job_id=default_job.id, actor=default_user)
|
||||||
|
|
||||||
# Verify we get the most recent statistics
|
# Verify we get the most recent statistics
|
||||||
assert usage_stats.completion_tokens == 200
|
assert usage_stats.completion_tokens == 300
|
||||||
assert usage_stats.prompt_tokens == 100
|
assert usage_stats.prompt_tokens == 150
|
||||||
assert usage_stats.total_tokens == 300
|
assert usage_stats.total_tokens == 450
|
||||||
|
assert usage_stats.step_count == 2
|
||||||
|
|
||||||
|
|
||||||
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
||||||
@@ -2920,18 +2928,19 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
|||||||
|
|
||||||
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
|
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
|
||||||
"""Test adding usage statistics for a nonexistent job."""
|
"""Test adding usage statistics for a nonexistent job."""
|
||||||
job_manager = server.job_manager
|
step_manager = server.step_manager
|
||||||
|
|
||||||
with pytest.raises(NoResultFound):
|
with pytest.raises(NoResultFound):
|
||||||
job_manager.add_job_usage(
|
step_manager.log_step(
|
||||||
|
provider_name="openai",
|
||||||
|
model="gpt-4",
|
||||||
|
context_window_limit=8192,
|
||||||
job_id="nonexistent_job",
|
job_id="nonexistent_job",
|
||||||
usage=LettaUsageStatistics(
|
usage=UsageStatistics(
|
||||||
completion_tokens=100,
|
completion_tokens=100,
|
||||||
prompt_tokens=50,
|
prompt_tokens=50,
|
||||||
total_tokens=150,
|
total_tokens=150,
|
||||||
step_count=5,
|
|
||||||
),
|
),
|
||||||
step_id="step_1",
|
|
||||||
actor=default_user,
|
actor=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,19 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import delete
|
||||||
|
|
||||||
import letta.utils as utils
|
import letta.utils as utils
|
||||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||||
|
from letta.orm import Provider, Step
|
||||||
from letta.schemas.block import CreateBlock
|
from letta.schemas.block import CreateBlock
|
||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
||||||
|
from letta.schemas.providers import Provider as PydanticProvider
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
|
|
||||||
utils.DEBUG = True
|
utils.DEBUG = True
|
||||||
@@ -277,6 +281,10 @@ def org_id(server):
|
|||||||
yield org.id
|
yield org.id
|
||||||
|
|
||||||
# cleanup
|
# cleanup
|
||||||
|
with server.organization_manager.session_maker() as session:
|
||||||
|
session.execute(delete(Step))
|
||||||
|
session.execute(delete(Provider))
|
||||||
|
session.commit()
|
||||||
server.organization_manager.delete_organization_by_id(org.id)
|
server.organization_manager.delete_organization_by_id(org.id)
|
||||||
|
|
||||||
|
|
||||||
@@ -1098,3 +1106,72 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
|
|||||||
request.tool_ids = [b.id for b in base_tools[:-2]]
|
request.tool_ids = [b.id for b in base_tools[:-2]]
|
||||||
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
|
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
|
||||||
assert len(agent_state.tools) == len(base_tools) - 2
|
assert len(agent_state.tools) == len(base_tools) - 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||||
|
actor = server.user_manager.get_user_or_default(user_id)
|
||||||
|
provider = server.provider_manager.create_provider(
|
||||||
|
provider=PydanticProvider(
|
||||||
|
name="anthropic",
|
||||||
|
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||||
|
),
|
||||||
|
actor=actor,
|
||||||
|
)
|
||||||
|
agent = server.create_agent(
|
||||||
|
request=CreateAgent(
|
||||||
|
memory_blocks=[], llm="anthropic/claude-3-opus-20240229", context_window_limit=200000, embedding="openai/text-embedding-ada-002"
|
||||||
|
),
|
||||||
|
actor=actor,
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
|
||||||
|
|
||||||
|
usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
|
||||||
|
assert usage, "Sending message failed"
|
||||||
|
|
||||||
|
get_messages_response = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id)
|
||||||
|
assert len(get_messages_response) > 0, "Retrieving messages failed"
|
||||||
|
|
||||||
|
step_ids = set([msg.step_id for msg in get_messages_response])
|
||||||
|
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
||||||
|
for step_id in step_ids:
|
||||||
|
step = server.step_manager.get_step(step_id=step_id)
|
||||||
|
assert step, "Step was not logged correctly"
|
||||||
|
assert step.provider_id == provider.id
|
||||||
|
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||||
|
assert step.model == agent.llm_config.model
|
||||||
|
assert step.context_window_limit == agent.llm_config.context_window
|
||||||
|
completion_tokens += int(step.completion_tokens)
|
||||||
|
prompt_tokens += int(step.prompt_tokens)
|
||||||
|
total_tokens += int(step.total_tokens)
|
||||||
|
|
||||||
|
assert completion_tokens == usage.completion_tokens
|
||||||
|
assert prompt_tokens == usage.prompt_tokens
|
||||||
|
assert total_tokens == usage.total_tokens
|
||||||
|
|
||||||
|
server.provider_manager.delete_provider_by_id(provider.id)
|
||||||
|
|
||||||
|
existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
|
||||||
|
|
||||||
|
usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
|
||||||
|
assert usage, "Sending message failed"
|
||||||
|
|
||||||
|
get_messages_response = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id)
|
||||||
|
assert len(get_messages_response) > 0, "Retrieving messages failed"
|
||||||
|
|
||||||
|
step_ids = set([msg.step_id for msg in get_messages_response])
|
||||||
|
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
||||||
|
for step_id in step_ids:
|
||||||
|
step = server.step_manager.get_step(step_id=step_id)
|
||||||
|
assert step, "Step was not logged correctly"
|
||||||
|
assert step.provider_id == None
|
||||||
|
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||||
|
assert step.model == agent.llm_config.model
|
||||||
|
assert step.context_window_limit == agent.llm_config.context_window
|
||||||
|
completion_tokens += int(step.completion_tokens)
|
||||||
|
prompt_tokens += int(step.prompt_tokens)
|
||||||
|
total_tokens += int(step.total_tokens)
|
||||||
|
|
||||||
|
assert completion_tokens == usage.completion_tokens
|
||||||
|
assert prompt_tokens == usage.prompt_tokens
|
||||||
|
assert total_tokens == usage.total_tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user