From 89fef2d45de575d2cbe8daab931bb0ed0f2e7ad8 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 5 Feb 2025 11:16:33 -0800 Subject: [PATCH] feat: add model_endpoint to steps table (#902) --- ...210ca_add_model_endpoint_to_steps_table.py | 31 +++++++++++++++++++ letta/agent.py | 1 + letta/orm/step.py | 1 + letta/schemas/step.py | 1 + letta/services/step_manager.py | 2 ++ tests/test_managers.py | 4 +++ 6 files changed, 40 insertions(+) create mode 100644 alembic/versions/dfafcf8210ca_add_model_endpoint_to_steps_table.py diff --git a/alembic/versions/dfafcf8210ca_add_model_endpoint_to_steps_table.py b/alembic/versions/dfafcf8210ca_add_model_endpoint_to_steps_table.py new file mode 100644 index 00000000..df3b4278 --- /dev/null +++ b/alembic/versions/dfafcf8210ca_add_model_endpoint_to_steps_table.py @@ -0,0 +1,31 @@ +"""add model endpoint to steps table + +Revision ID: dfafcf8210ca +Revises: f922ca16e42c +Create Date: 2025-02-04 16:45:34.132083 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "dfafcf8210ca" +down_revision: Union[str, None] = "f922ca16e42c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("steps", sa.Column("model_endpoint", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("steps", "model_endpoint") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 4284e86f..c84f80e9 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -790,6 +790,7 @@ class Agent(BaseAgent): actor=self.user, provider_name=self.agent_state.llm_config.model_endpoint_type, model=self.agent_state.llm_config.model, + model_endpoint=self.agent_state.llm_config.model_endpoint, context_window_limit=self.agent_state.llm_config.context_window, usage=response.usage, # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature diff --git a/letta/orm/step.py b/letta/orm/step.py index 8ea5f313..e5c33347 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -35,6 +35,7 @@ class Step(SqlalchemyBase): ) provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.") model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.") + model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.") context_window_limit: Mapped[Optional[int]] = mapped_column( None, nullable=True, doc="The context window limit configured for this step." ) diff --git a/letta/schemas/step.py b/letta/schemas/step.py index c3482878..98bc51c7 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -20,6 +20,7 @@ class Step(StepBase): ) provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.") model: Optional[str] = Field(None, description="The name of the model used for this step.") + model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.") context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.") completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.") prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.") diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 49dbf316..a316eda6 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -55,6 +55,7 @@ class StepManager: actor: PydanticUser, provider_name: str, model: str, + model_endpoint: Optional[str], context_window_limit: int, usage: UsageStatistics, provider_id: Optional[str] = None, @@ -66,6 +67,7 @@ class StepManager: "provider_id": provider_id, "provider_name": provider_name, "model": model, + "model_endpoint": model_endpoint, "context_window_limit": context_window_limit, "completion_tokens": usage.completion_tokens, "prompt_tokens": usage.prompt_tokens, diff --git a/tests/test_managers.py b/tests/test_managers.py index a4d8adce..43ffbaa7 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3128,6 +3128,7 @@ def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_us step_manager.log_step( provider_name="openai", model="gpt-4", + model_endpoint="https://api.openai.com/v1", context_window_limit=8192, job_id=default_job.id, usage=UsageStatistics( @@ -3169,6 +3170,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u step_manager.log_step( provider_name="openai", model="gpt-4", + model_endpoint="https://api.openai.com/v1", context_window_limit=8192, job_id=default_job.id, usage=UsageStatistics( @@ -3183,6 +3185,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u step_manager.log_step( provider_name="openai", model="gpt-4", + model_endpoint="https://api.openai.com/v1", context_window_limit=8192, job_id=default_job.id, usage=UsageStatistics( @@ -3219,6 +3222,7 @@ def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user): step_manager.log_step( provider_name="openai", model="gpt-4", + model_endpoint="https://api.openai.com/v1", context_window_limit=8192, job_id="nonexistent_job", usage=UsageStatistics(