From ef6fce8e0f753244701cfdf41e595edef83f92d5 Mon Sep 17 00:00:00 2001 From: cthomas Date: Sat, 18 Jan 2025 12:20:10 -0800 Subject: [PATCH] feat: add schema/db for new steps table (#669) --- ...b_repurpose_jobusagestatistics_for_new_.py | 116 ++++++++++++++++++ letta/agent.py | 22 ++++ letta/llm_api/anthropic.py | 10 +- letta/orm/__init__.py | 2 +- letta/orm/job.py | 6 +- letta/orm/job_usage_statistics.py | 30 ----- letta/orm/message.py | 6 +- letta/orm/step.py | 54 ++++++++ letta/schemas/message.py | 1 + letta/schemas/step.py | 31 +++++ letta/server/rest_api/routers/v1/agents.py | 3 - letta/server/server.py | 2 + letta/services/job_manager.py | 24 ++-- letta/services/provider_manager.py | 28 +++-- letta/services/step_manager.py | 87 +++++++++++++ tests/test_managers.py | 51 ++++---- tests/test_server.py | 77 ++++++++++++ 17 files changed, 466 insertions(+), 84 deletions(-) create mode 100644 alembic/versions/416b9d2db10b_repurpose_jobusagestatistics_for_new_.py delete mode 100644 letta/orm/job_usage_statistics.py create mode 100644 letta/orm/step.py create mode 100644 letta/schemas/step.py create mode 100644 letta/services/step_manager.py diff --git a/alembic/versions/416b9d2db10b_repurpose_jobusagestatistics_for_new_.py b/alembic/versions/416b9d2db10b_repurpose_jobusagestatistics_for_new_.py new file mode 100644 index 00000000..332fdb3f --- /dev/null +++ b/alembic/versions/416b9d2db10b_repurpose_jobusagestatistics_for_new_.py @@ -0,0 +1,116 @@ +"""Repurpose JobUsageStatistics for new Steps table + +Revision ID: 416b9d2db10b +Revises: 25fc99e97839 +Create Date: 2025-01-17 11:27:42.115755 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "416b9d2db10b" +down_revision: Union[str, None] = "25fc99e97839" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Rename the table + op.rename_table("job_usage_statistics", "steps") + + # Rename the foreign key constraint and drop non-null constraint + op.alter_column("steps", "job_id", nullable=True) + op.drop_constraint("fk_job_usage_statistics_job_id", "steps", type_="foreignkey") + + # Change id field from int to string + op.execute("ALTER TABLE steps RENAME COLUMN id TO old_id") + op.add_column("steps", sa.Column("id", sa.String(), nullable=True)) + op.execute("""UPDATE steps SET id = 'step-' || gen_random_uuid()::text""") + op.drop_column("steps", "old_id") + op.alter_column("steps", "id", nullable=False) + op.create_primary_key("pk_steps_id", "steps", ["id"]) + + # Add new columns + op.add_column("steps", sa.Column("origin", sa.String(), nullable=True)) + op.add_column("steps", sa.Column("organization_id", sa.String(), nullable=True)) + op.add_column("steps", sa.Column("provider_id", sa.String(), nullable=True)) + op.add_column("steps", sa.Column("provider_name", sa.String(), nullable=True)) + op.add_column("steps", sa.Column("model", sa.String(), nullable=True)) + op.add_column("steps", sa.Column("context_window_limit", sa.Integer(), nullable=True)) + op.add_column( + "steps", + sa.Column("completion_tokens_details", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + ) + op.add_column( + "steps", + sa.Column("tags", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + ) + op.add_column("steps", sa.Column("tid", sa.String(), nullable=True)) + + # Add new foreign key constraint for provider_id + op.create_foreign_key("fk_steps_organization_id", "steps", "providers", ["provider_id"], ["id"], ondelete="RESTRICT") + + # Add new foreign key 
constraint for provider_id + op.create_foreign_key("fk_steps_provider_id", "steps", "organizations", ["organization_id"], ["id"], ondelete="RESTRICT") + + # Add new foreign key constraint for provider_id + op.create_foreign_key("fk_steps_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="SET NULL") + + # Drop old step_id and step_count columns which aren't in the new model + op.drop_column("steps", "step_id") + op.drop_column("steps", "step_count") + + # Add step_id to messages table + op.add_column("messages", sa.Column("step_id", sa.String(), nullable=True)) + op.create_foreign_key("fk_messages_step_id", "messages", "steps", ["step_id"], ["id"], ondelete="SET NULL") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Remove step_id from messages first to avoid foreign key conflicts + op.drop_constraint("fk_messages_step_id", "messages", type_="foreignkey") + op.drop_column("messages", "step_id") + + # Restore old step_count and step_id column + op.add_column("steps", sa.Column("step_count", sa.Integer(), nullable=True)) + op.add_column("steps", sa.Column("step_id", sa.String(), nullable=True)) + + # Drop new columns and constraints + op.drop_constraint("fk_steps_provider_id", "steps", type_="foreignkey") + op.drop_constraint("fk_steps_organization_id", "steps", type_="foreignkey") + op.drop_constraint("fk_steps_job_id", "steps", type_="foreignkey") + + op.drop_column("steps", "tid") + op.drop_column("steps", "tags") + op.drop_column("steps", "completion_tokens_details") + op.drop_column("steps", "context_window_limit") + op.drop_column("steps", "model") + op.drop_column("steps", "provider_name") + op.drop_column("steps", "provider_id") + op.drop_column("steps", "organization_id") + op.drop_column("steps", "origin") + + # Add constraints back + op.execute("DELETE FROM steps WHERE job_id IS NULL") + op.alter_column("steps", "job_id", nullable=False) + 
op.create_foreign_key("fk_job_usage_statistics_job_id", "steps", "jobs", ["job_id"], ["id"], ondelete="CASCADE") + + # Change id field from string back to int + op.add_column("steps", sa.Column("old_id", sa.Integer(), nullable=True)) + op.execute("""UPDATE steps SET old_id = CAST(ABS(hashtext(REPLACE(id, 'step-', '')::text)) AS integer)""") + op.drop_column("steps", "id") + op.execute("ALTER TABLE steps RENAME COLUMN old_id TO id") + op.alter_column("steps", "id", nullable=False) + op.create_primary_key("pk_steps_id", "steps", ["id"]) + + # Rename the table + op.rename_table("steps", "job_usage_statistics") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 29575ad5..978fbda3 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -49,6 +49,8 @@ from letta.services.helpers.agent_manager_helper import check_supports_structure from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.provider_manager import ProviderManager +from letta.services.step_manager import StepManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.streaming_interface import StreamingRefreshCLIInterface from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message @@ -130,8 +132,10 @@ class Agent(BaseAgent): # Create the persistence manager object based on the AgentState info self.message_manager = MessageManager() self.passage_manager = PassageManager() + self.provider_manager = ProviderManager() self.agent_manager = AgentManager() self.job_manager = JobManager() + self.step_manager = StepManager() # State needed for heartbeat pausing @@ -764,6 +768,24 @@ class Agent(BaseAgent): f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}" ) + # Log 
step - this must happen before messages are persisted + step = self.step_manager.log_step( + actor=self.user, + provider_name=self.agent_state.llm_config.model_endpoint_type, + model=self.agent_state.llm_config.model, + context_window_limit=self.agent_state.llm_config.context_window, + usage=response.usage, + # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature + provider_id=( + self.provider_manager.get_anthropic_override_provider_id() + if self.agent_state.llm_config.model_endpoint_type == "anthropic" + else None + ), + job_id=job_id, + ) + for message in all_new_messages: + message.step_id = step.id + # Persisting into Messages self.agent_state = self.agent_manager.append_to_in_context_messages( all_new_messages, agent_id=self.agent_state.id, actor=self.user diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 87adfc5a..b562d466 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -14,6 +14,7 @@ from letta.schemas.openai.chat_completion_response import ( Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype ) from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics +from letta.services.provider_manager import ProviderManager from letta.settings import model_settings from letta.utils import get_utc_time, smart_urljoin @@ -39,9 +40,6 @@ MODEL_LIST = [ DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence." 
-if model_settings.anthropic_api_key: - anthropic_client = anthropic.Anthropic() - def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int: for model_dict in anthropic_get_model_list(url=url, api_key=api_key): @@ -397,6 +395,12 @@ def anthropic_chat_completions_request( betas: List[str] = ["tools-2024-04-04"], ) -> ChatCompletionResponse: """https://docs.anthropic.com/claude/docs/tool-use""" + anthropic_client = None + anthropic_override_key = ProviderManager().get_anthropic_override_key() + if anthropic_override_key: + anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) + elif model_settings.anthropic_api_key: + anthropic_client = anthropic.Anthropic() data = _prepare_anthropic_request(data, inner_thoughts_xml_tag) response = anthropic_client.beta.messages.create( **data, diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 74cfe0c4..5898dd80 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -6,7 +6,6 @@ from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata from letta.orm.job import Job from letta.orm.job_messages import JobMessage -from letta.orm.job_usage_statistics import JobUsageStatistics from letta.orm.message import Message from letta.orm.organization import Organization from letta.orm.passage import AgentPassage, BasePassage, SourcePassage @@ -14,6 +13,7 @@ from letta.orm.provider import Provider from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.sources_agents import SourcesAgents +from letta.orm.step import Step from letta.orm.tool import Tool from letta.orm.tools_agents import ToolsAgents from letta.orm.user import User diff --git a/letta/orm/job.py b/letta/orm/job.py index 95e67006..aacb6785 100644 --- a/letta/orm/job.py +++ b/letta/orm/job.py @@ -13,8 +13,8 @@ from letta.schemas.letta_request import LettaRequestConfig if 
TYPE_CHECKING: from letta.orm.job_messages import JobMessage - from letta.orm.job_usage_statistics import JobUsageStatistics from letta.orm.message import Message + from letta.orm.step import Step from letta.orm.user import User @@ -41,9 +41,7 @@ class Job(SqlalchemyBase, UserMixin): # relationships user: Mapped["User"] = relationship("User", back_populates="jobs") job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan") - usage_statistics: Mapped[list["JobUsageStatistics"]] = relationship( - "JobUsageStatistics", back_populates="job", cascade="all, delete-orphan" - ) + steps: Mapped[List["Step"]] = relationship("Step", back_populates="job", cascade="save-update") @property def messages(self) -> List["Message"]: diff --git a/letta/orm/job_usage_statistics.py b/letta/orm/job_usage_statistics.py deleted file mode 100644 index 0a355d69..00000000 --- a/letta/orm/job_usage_statistics.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import TYPE_CHECKING, Optional - -from sqlalchemy import ForeignKey -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from letta.orm.sqlalchemy_base import SqlalchemyBase - -if TYPE_CHECKING: - from letta.orm.job import Job - - -class JobUsageStatistics(SqlalchemyBase): - """Tracks usage statistics for jobs, with future support for per-step tracking.""" - - __tablename__ = "job_usage_statistics" - - id: Mapped[int] = mapped_column(primary_key=True, doc="Unique identifier for the usage statistics entry") - job_id: Mapped[str] = mapped_column( - ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the job these statistics belong to" - ) - step_id: Mapped[Optional[str]] = mapped_column( - nullable=True, doc="ID of the specific step within the job (for future per-step tracking)" - ) - completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent") - prompt_tokens: Mapped[int] = mapped_column(default=0, 
doc="Number of tokens in the prompt") - total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent") - step_count: Mapped[int] = mapped_column(default=0, doc="Number of steps taken by the agent") - - # Relationship back to the job - job: Mapped["Job"] = relationship("Job", back_populates="usage_statistics") diff --git a/letta/orm/message.py b/letta/orm/message.py index 231462a4..e06fae3d 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -1,6 +1,6 @@ from typing import Optional -from sqlalchemy import Index +from sqlalchemy import ForeignKey, Index from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.custom_columns import ToolCallColumn @@ -24,10 +24,14 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios") tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information") tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call") + step_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("steps.id", ondelete="SET NULL"), nullable=True, doc="ID of the step that this message belongs to" + ) # Relationships agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin") organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin") + step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin") # Job relationship job_message: Mapped[Optional["JobMessage"]] = relationship( diff --git a/letta/orm/step.py b/letta/orm/step.py new file mode 100644 index 00000000..8ea5f313 --- /dev/null +++ b/letta/orm/step.py @@ -0,0 +1,54 @@ +import uuid +from typing import TYPE_CHECKING, Dict, List, Optional + +from sqlalchemy import JSON, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from 
letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.step import Step as PydanticStep + +if TYPE_CHECKING: + from letta.orm.job import Job + from letta.orm.provider import Provider + + +class Step(SqlalchemyBase): + """Tracks all metadata for agent step.""" + + __tablename__ = "steps" + __pydantic_model__ = PydanticStep + + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"step-{uuid.uuid4()}") + origin: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The surface that this agent step was initiated from.") + organization_id: Mapped[str] = mapped_column( + ForeignKey("organizations.id", ondelete="RESTRICT"), + nullable=True, + doc="The unique identifier of the organization that this step ran for", + ) + provider_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("providers.id", ondelete="RESTRICT"), + nullable=True, + doc="The unique identifier of the provider that was configured for this step", + ) + job_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identified of the job run that triggered this step" + ) + provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.") + model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.") + context_window_limit: Mapped[Optional[int]] = mapped_column( + None, nullable=True, doc="The context window limit configured for this step." 
+ ) + completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent") + prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt") + total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent") + completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.") + tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.") + tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.") + + # Relationships (foreign keys) + organization: Mapped[Optional["Organization"]] = relationship("Organization") + provider: Mapped[Optional["Provider"]] = relationship("Provider") + job: Mapped[Optional["Job"]] = relationship("Job", back_populates="steps") + + # Relationships (backrefs) + messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index df09aa25..41b85259 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -99,6 +99,7 @@ class Message(BaseMessage): name: Optional[str] = Field(None, description="The name of the participant.") tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.") tool_call_id: Optional[str] = Field(None, description="The id of the tool call.") + step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.") # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") diff --git a/letta/schemas/step.py b/letta/schemas/step.py new file mode 100644 index 00000000..c3482878 --- /dev/null +++ b/letta/schemas/step.py @@ -0,0 +1,31 @@ +from typing 
import Dict, List, Optional + +from pydantic import Field + +from letta.schemas.letta_base import LettaBase +from letta.schemas.message import Message + + +class StepBase(LettaBase): + __id_prefix__ = "step" + + +class Step(StepBase): + id: str = Field(..., description="The id of the step. Assigned by the database.") + origin: Optional[str] = Field(None, description="The surface that this agent step was initiated from.") + organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the step.") + provider_id: Optional[str] = Field(None, description="The unique identifier of the provider that was configured for this step") + job_id: Optional[str] = Field( + None, description="The unique identifier of the job that this step belongs to. Only included for async calls." + ) + provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.") + model: Optional[str] = Field(None, description="The name of the model used for this step.") + context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.") + completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.") + prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.") + total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.") + completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.") + + tags: List[str] = Field([], description="Metadata tags.") + tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.") + messages: List[Message] = Field([], description="The messages generated during this step.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 
d062a54a..8dccb9e6 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -595,9 +595,6 @@ async def process_message_background( ) server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor) - # Add job usage statistics - server.job_manager.add_job_usage(job_id=job_id, usage=result.usage, actor=actor) - except Exception as e: # Update job status to failed job_update = JobUpdate( diff --git a/letta/server/server.py b/letta/server/server.py index da4af5fe..352b982e 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -72,6 +72,7 @@ from letta.services.per_agent_lock_manager import PerAgentLockManager from letta.services.provider_manager import ProviderManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager +from letta.services.step_manager import StepManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager @@ -293,6 +294,7 @@ class SyncServer(Server): self.job_manager = JobManager() self.agent_manager = AgentManager() self.provider_manager = ProviderManager() + self.step_manager = StepManager() # Managers that interface with parallelism self.per_agent_lock_manager = PerAgentLockManager() diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index b8ea803b..f014c568 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -1,3 +1,5 @@ +from functools import reduce +from operator import add from typing import List, Literal, Optional, Union from sqlalchemy import select @@ -7,9 +9,9 @@ from letta.orm.enums import JobType from letta.orm.errors import NoResultFound from letta.orm.job import Job as JobModel from letta.orm.job_messages import JobMessage -from letta.orm.job_usage_statistics import JobUsageStatistics from 
letta.orm.message import Message as MessageModel from letta.orm.sqlalchemy_base import AccessType +from letta.orm.step import Step from letta.schemas.enums import JobStatus, MessageRole from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate @@ -193,12 +195,7 @@ class JobManager: self._verify_job_access(session, job_id, actor) # Get the latest usage statistics for the job - latest_stats = ( - session.query(JobUsageStatistics) - .filter(JobUsageStatistics.job_id == job_id) - .order_by(JobUsageStatistics.created_at.desc()) - .first() - ) + latest_stats = session.query(Step).filter(Step.job_id == job_id).order_by(Step.created_at.desc()).all() if not latest_stats: return LettaUsageStatistics( @@ -209,10 +206,10 @@ class JobManager: ) return LettaUsageStatistics( - completion_tokens=latest_stats.completion_tokens, - prompt_tokens=latest_stats.prompt_tokens, - total_tokens=latest_stats.total_tokens, - step_count=latest_stats.step_count, + completion_tokens=reduce(add, (step.completion_tokens or 0 for step in latest_stats), 0), + prompt_tokens=reduce(add, (step.prompt_tokens or 0 for step in latest_stats), 0), + total_tokens=reduce(add, (step.total_tokens or 0 for step in latest_stats), 0), + step_count=len(latest_stats), ) @enforce_types @@ -239,8 +236,9 @@ class JobManager: # First verify job exists and user has access self._verify_job_access(session, job_id, actor, access=["write"]) - # Create new usage statistics entry - usage_stats = JobUsageStatistics( + # Manually log step with usage data + # TODO(@caren): log step under the hood and remove this + usage_stats = Step( job_id=job_id, completion_tokens=usage.completion_tokens, prompt_tokens=usage.prompt_tokens, diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 6b06bf79..989e7eb7 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -48,9 +48,13 @@ class ProviderManager: def delete_provider_by_id(self, 
provider_id: str): """Delete a provider.""" with self.session_maker() as session: - # Delete from provider table - provider = ProviderModel.read(db_session=session, identifier=provider_id) - provider.hard_delete(session) + # Clear api key field + existing_provider = ProviderModel.read(db_session=session, identifier=provider_id) + existing_provider.api_key = None + existing_provider.update(session) + + # Soft delete in provider table + existing_provider.delete(session) session.commit() @@ -62,9 +66,17 @@ class ProviderManager: return [provider.to_pydantic() for provider in results] @enforce_types - def get_anthropic_key_override(self) -> Optional[str]: - """Helper function to fetch custom anthropic key for v0 BYOK feature""" - providers = self.list_providers(limit=1) - if len(providers) == 1 and providers[0].name == "anthropic": - return providers[0].api_key + def get_anthropic_override_provider_id(self) -> Optional[str]: + """Helper function to fetch custom anthropic provider id for v0 BYOK feature""" + anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"] + if len(anthropic_provider) != 0: + return anthropic_provider[0].id + return None + + @enforce_types + def get_anthropic_override_key(self) -> Optional[str]: + """Helper function to fetch custom anthropic key for v0 BYOK feature""" + anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"] + if len(anthropic_provider) != 0: + return anthropic_provider[0].api_key return None diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py new file mode 100644 index 00000000..cbeee458 --- /dev/null +++ b/letta/services/step_manager.py @@ -0,0 +1,87 @@ +from typing import List, Literal, Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from letta.orm.errors import NoResultFound +from letta.orm.job import Job as JobModel +from letta.orm.sqlalchemy_base import AccessType +from 
letta.orm.step import Step as StepModel +from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.step import Step as PydanticStep +from letta.schemas.user import User as PydanticUser +from letta.utils import enforce_types + + +class StepManager: + + def __init__(self): + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def log_step( + self, + actor: PydanticUser, + provider_name: str, + model: str, + context_window_limit: int, + usage: UsageStatistics, + provider_id: Optional[str] = None, + job_id: Optional[str] = None, + ) -> PydanticStep: + step_data = { + "origin": None, + "organization_id": actor.organization_id, + "provider_id": provider_id, + "provider_name": provider_name, + "model": model, + "context_window_limit": context_window_limit, + "completion_tokens": usage.completion_tokens, + "prompt_tokens": usage.prompt_tokens, + "total_tokens": usage.total_tokens, + "job_id": job_id, + "tags": [], + "tid": None, + } + with self.session_maker() as session: + if job_id: + self._verify_job_access(session, job_id, actor, access=["write"]) + new_step = StepModel(**step_data) + new_step.create(session) + return new_step.to_pydantic() + + @enforce_types + def get_step(self, step_id: str) -> PydanticStep: + with self.session_maker() as session: + step = StepModel.read(db_session=session, identifier=step_id) + return step.to_pydantic() + + def _verify_job_access( + self, + session: Session, + job_id: str, + actor: PydanticUser, + access: List[Literal["read", "write", "delete"]] = ["read"], + ) -> JobModel: + """ + Verify that a job exists and the user has the required access. 
+ + Args: + session: The database session + job_id: The ID of the job to verify + actor: The user making the request + + Returns: + The job if it exists and the user has access + + Raises: + NoResultFound: If the job does not exist or user does not have access + """ + job_query = select(JobModel).where(JobModel.id == job_id) + job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER) + job = session.execute(job_query).scalar_one_or_none() + if not job: + raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") + return job diff --git a/tests/test_managers.py b/tests/test_managers.py index efe736f6..c68a395c 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -26,6 +26,7 @@ from letta.orm import ( Source, SourcePassage, SourcesAgents, + Step, Tool, ToolsAgents, User, @@ -46,6 +47,7 @@ from letta.schemas.letta_request import LettaRequestConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate +from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.passage import Passage as PydanticPassage @@ -56,7 +58,6 @@ from letta.schemas.source import SourceUpdate from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolUpdate from letta.schemas.tool_rule import InitToolRule -from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate from letta.server.server import SyncServer @@ -100,6 +101,7 @@ def clear_tables(server: SyncServer): session.execute(delete(Tool)) # Clear all records from the Tool table session.execute(delete(Agent)) session.execute(delete(User)) # Clear all records 
from the user table + session.execute(delete(Step)) session.execute(delete(Provider)) session.execute(delete(Organization)) # Clear all records from the organization table session.commit() # Commit the deletion @@ -2835,17 +2837,19 @@ def test_get_run_messages_cursor(server: SyncServer, default_user: PydanticUser, def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user): """Test adding and retrieving job usage statistics.""" job_manager = server.job_manager + step_manager = server.step_manager # Add usage statistics - job_manager.add_job_usage( + step_manager.log_step( + provider_name="openai", + model="gpt-4", + context_window_limit=8192, job_id=default_job.id, - usage=LettaUsageStatistics( + usage=UsageStatistics( completion_tokens=100, prompt_tokens=50, total_tokens=150, - step_count=5, ), - step_id="step_1", actor=default_user, ) @@ -2874,30 +2878,33 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user): """Test adding multiple usage statistics entries for a job.""" job_manager = server.job_manager + step_manager = server.step_manager # Add first usage statistics entry - job_manager.add_job_usage( + step_manager.log_step( + provider_name="openai", + model="gpt-4", + context_window_limit=8192, job_id=default_job.id, - usage=LettaUsageStatistics( + usage=UsageStatistics( completion_tokens=100, prompt_tokens=50, total_tokens=150, - step_count=5, ), - step_id="step_1", actor=default_user, ) # Add second usage statistics entry - job_manager.add_job_usage( + step_manager.log_step( + provider_name="openai", + model="gpt-4", + context_window_limit=8192, job_id=default_job.id, - usage=LettaUsageStatistics( + usage=UsageStatistics( completion_tokens=200, prompt_tokens=100, total_tokens=300, - step_count=10, ), - step_id="step_2", actor=default_user, ) @@ -2905,9 +2912,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, 
default_job, default_u usage_stats = job_manager.get_job_usage(job_id=default_job.id, actor=default_user) # Verify we get the most recent statistics - assert usage_stats.completion_tokens == 200 - assert usage_stats.prompt_tokens == 100 - assert usage_stats.total_tokens == 300 + assert usage_stats.completion_tokens == 300 + assert usage_stats.prompt_tokens == 150 + assert usage_stats.total_tokens == 450 + assert usage_stats.step_count == 2 def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): @@ -2920,18 +2928,19 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user): """Test adding usage statistics for a nonexistent job.""" - job_manager = server.job_manager + step_manager = server.step_manager with pytest.raises(NoResultFound): - job_manager.add_job_usage( + step_manager.log_step( + provider_name="openai", + model="gpt-4", + context_window_limit=8192, job_id="nonexistent_job", - usage=LettaUsageStatistics( + usage=UsageStatistics( completion_tokens=100, prompt_tokens=50, total_tokens=150, - step_count=5, ), - step_id="step_1", actor=default_user, ) diff --git a/tests/test_server.py b/tests/test_server.py index 763400b6..b732e95b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,15 +1,19 @@ import json +import os import uuid import warnings from typing import List, Tuple import pytest +from sqlalchemy import delete import letta.utils as utils from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS +from letta.orm import Provider, Step from letta.schemas.block import CreateBlock from letta.schemas.enums import MessageRole from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage +from letta.schemas.providers import Provider as PydanticProvider from letta.schemas.user import User utils.DEBUG = True @@ -277,6 +281,10 @@ def org_id(server): 
def _assert_steps_match_usage(server: SyncServer, messages, usage, agent, expected_provider_id) -> None:
    """Assert every step referenced by *messages* was logged with the agent's LLM
    config and *expected_provider_id*, and that summed token counts equal *usage*."""
    # Set comprehension instead of set([...]); several messages may share a step.
    step_ids = {msg.step_id for msg in messages}
    completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
    for step_id in step_ids:
        step = server.step_manager.get_step(step_id=step_id)
        assert step, "Step was not logged correctly"
        # Parameterized comparison covers both the override (provider.id) and the
        # fallback (None) cases; avoids `== None` (PEP 8 E711).
        assert step.provider_id == expected_provider_id
        assert step.provider_name == agent.llm_config.model_endpoint_type
        assert step.model == agent.llm_config.model
        assert step.context_window_limit == agent.llm_config.context_window
        completion_tokens += int(step.completion_tokens)
        prompt_tokens += int(step.prompt_tokens)
        total_tokens += int(step.total_tokens)

    assert completion_tokens == usage.completion_tokens
    assert prompt_tokens == usage.prompt_tokens
    assert total_tokens == usage.total_tokens


def test_messages_with_provider_override(server: SyncServer, user_id: str):
    """Steps must record the provider override while it exists, and None after it is deleted."""
    actor = server.user_manager.get_user_or_default(user_id)
    provider = server.provider_manager.create_provider(
        provider=PydanticProvider(
            name="anthropic",
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        ),
        actor=actor,
    )
    agent = server.create_agent(
        request=CreateAgent(
            memory_blocks=[], llm="anthropic/claude-3-opus-20240229", context_window_limit=200000, embedding="openai/text-embedding-ada-002"
        ),
        actor=actor,
    )

    # While the provider override exists, steps should carry its id.
    existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
    usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
    assert usage, "Sending message failed"
    get_messages_response = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id)
    assert len(get_messages_response) > 0, "Retrieving messages failed"
    _assert_steps_match_usage(server, get_messages_response, usage, agent, provider.id)

    # After deleting the override, steps should fall back to provider_id=None.
    server.provider_manager.delete_provider_by_id(provider.id)
    existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
    usage = server.user_message(user_id=actor.id, agent_id=agent.id, message="Test message")
    assert usage, "Sending message failed"
    get_messages_response = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor, cursor=existing_messages[-1].id)
    assert len(get_messages_response) > 0, "Retrieving messages failed"
    _assert_steps_match_usage(server, get_messages_response, usage, agent, None)