From 4047f02386a55bdd0cdaa1c8e12224b2b7f6d649 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 7 Jan 2025 22:12:55 -0800 Subject: [PATCH] feat: support custom api keys for cloud (#533) --- .../915b68780108_add_providers_data_to_orm.py | 47 ++++++++++++ letta/llm_api/llm_api_tools.py | 6 +- letta/orm/__init__.py | 1 + letta/orm/organization.py | 2 + letta/orm/provider.py | 23 ++++++ letta/providers.py | 23 +++++- letta/server/rest_api/routers/v1/providers.py | 72 +++++++++++++++++++ letta/server/server.py | 12 +++- letta/services/provider_manager.py | 63 ++++++++++++++++ 9 files changed, 244 insertions(+), 5 deletions(-) create mode 100644 alembic/versions/915b68780108_add_providers_data_to_orm.py create mode 100644 letta/orm/provider.py create mode 100644 letta/server/rest_api/routers/v1/providers.py create mode 100644 letta/services/provider_manager.py diff --git a/alembic/versions/915b68780108_add_providers_data_to_orm.py b/alembic/versions/915b68780108_add_providers_data_to_orm.py new file mode 100644 index 00000000..973b8dbb --- /dev/null +++ b/alembic/versions/915b68780108_add_providers_data_to_orm.py @@ -0,0 +1,47 @@ +"""Add providers data to ORM + +Revision ID: 915b68780108 +Revises: 400501b04bf0 +Create Date: 2025-01-07 10:49:04.717058 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "915b68780108" +down_revision: Union[str, None] = "400501b04bf0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "providers", + sa.Column("name", sa.String(), nullable=False), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("providers") + # ### end Alembic commands ### diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index d83e8699..fc5252d3 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -22,6 +22,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype from letta.schemas.openai.chat_completion_response import ChatCompletionResponse +from letta.services.provider_manager import ProviderManager from letta.settings import ModelSettings from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface @@ -251,9 +252,12 @@ def create( tool_call = {"type": "function", "function": {"name": force_tool_call}} assert functions is not None + # load anthropic key from db in case a custom key has been stored + anthropic_key_override = ProviderManager().get_anthropic_key_override() + return anthropic_chat_completions_request( url=llm_config.model_endpoint, - api_key=model_settings.anthropic_api_key, + api_key=anthropic_key_override if anthropic_key_override else model_settings.anthropic_api_key, data=ChatCompletionRequest( model=llm_config.model, messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 437956a5..f5f0e478 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -8,6 +8,7 @@ from letta.orm.job import Job from letta.orm.message import Message from letta.orm.organization import Organization from letta.orm.passage import AgentPassage, BasePassage, SourcePassage +from letta.orm.provider import Provider from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.sources_agents import SourcesAgents diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 486cfcc4..cef5adbd 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from letta.orm.agent import Agent from letta.orm.file import FileMetadata + from letta.orm.provider import Provider from letta.orm.sandbox_config import AgentEnvironmentVariable from letta.orm.tool import Tool from letta.orm.user import User @@ -45,6 +46,7 @@ class Organization(SqlalchemyBase): "SourcePassage", back_populates="organization", cascade="all, delete-orphan" ) agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan") + providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan") @property def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: diff --git a/letta/orm/provider.py b/letta/orm/provider.py new file mode 100644 index 00000000..82c84c5a --- /dev/null +++ b/letta/orm/provider.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.providers import Provider as PydanticProvider + +if TYPE_CHECKING: + from letta.orm.organization import Organization + + +class Provider(SqlalchemyBase, OrganizationMixin): + """Provider ORM class""" + + __tablename__ = "providers" + __pydantic_model__ = PydanticProvider + + name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider") + api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="providers") diff --git a/letta/providers.py b/letta/providers.py index 87a7557d..deacb0a3 100644 --- a/letta/providers.py +++ b/letta/providers.py @@ -1,16 +1,24 @@ from typing import List, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import Field, model_validator from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.letta_base import LettaBase from letta.schemas.llm_config import LLMConfig +from letta.services.organization_manager import OrganizationManager -class Provider(BaseModel): +class ProviderBase(LettaBase): + __id_prefix__ = "provider" + + +class Provider(ProviderBase): name: str = Field(..., description="The name of the provider") + api_key: Optional[str] = Field(None, description="API key used for requests to the provider.") + organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user") def list_llm_models(self) -> List[LLMConfig]: return [] @@ -29,6 +37,17 @@ class Provider(BaseModel): return f"{self.name}/{model_name}" +class ProviderCreate(ProviderBase): + name: str = Field(..., description="The name of the provider.") + api_key: str = Field(..., description="API key used for requests to the provider.") + organization_id: str = Field(..., description="The organization id that this provider information pertains to.") + + +class ProviderUpdate(ProviderBase): + id: str = Field(..., description="The id of the provider to update.") + api_key: str = Field(..., description="API key used for requests to the provider.") + + class LettaProvider(Provider): name: str = "letta" diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py new file mode 100644 index 00000000..9ad9022e --- /dev/null +++ b/letta/server/rest_api/routers/v1/providers.py @@ -0,0 +1,72 @@ +from fastapi import APIRouter, Depends + +from letta.providers import Provider, ProviderCreate, ProviderUpdate +from letta.server.rest_api.utils import get_letta_server + +if TYPE_CHECKING: + from letta.server.server import SyncServer + +router = APIRouter(prefix="/providers", tags=["providers", "admin"]) + + +@router.get("/", tags=["admin"], response_model=List[Provider], operation_id="list_providers") +def list_providers( + cursor: Optional[str] = Query(None), + limit: Optional[int] = Query(50), + server: "SyncServer" = Depends(get_letta_server), +): + """ + Get a list of all custom providers in the database + """ + try: + providers = server.provider_manager.list_providers(cursor=cursor, limit=limit) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + return providers + + +@router.post("/", tags=["admin"], response_model=Provider, operation_id="create_provider") +def create_provider( + request: ProviderCreate = Body(...), + server: "SyncServer" = Depends(get_letta_server), +): + """ + Create a new custom provider + """ + provider = Provider(**request.model_dump()) + provider = server.provider_manager.create_provider(provider) + return provider + + +@router.put("/", tags=["admin"], response_model=Provider, operation_id="update_provider") +def update_provider( + request: ProviderUpdate = Body(...), + server: "SyncServer" = Depends(get_letta_server), +): + """ + Update an existing custom provider + """ + provider = server.provider_manager.update_provider(request) + return provider + + +@router.delete("/", tags=["admin"], response_model=Provider, operation_id="delete_provider") +def delete_provider( + provider_id: str = Query(..., description="The provider_id key to be deleted."), + server: "SyncServer" = Depends(get_letta_server), +): + """ + Delete an existing custom provider + """ + try: + provider = server.provider_manager.get_provider_by_id(provider_id=provider_id) + if provider is None: + raise HTTPException(status_code=404, detail=f"Provider does not exist") + server.provider_manager.delete_provider_by_id(provider_id=provider_id) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + return user diff --git a/letta/server/server.py b/letta/server/server.py index 932e3c2c..4b86544a 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -66,6 +66,7 @@ from letta.services.message_manager import MessageManager from letta.services.organization_manager import OrganizationManager from letta.services.passage_manager import PassageManager from letta.services.per_agent_lock_manager import PerAgentLockManager +from letta.services.provider_manager import ProviderManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox @@ -290,6 +291,7 @@ class SyncServer(Server): self.message_manager = MessageManager() self.job_manager = JobManager() self.agent_manager = AgentManager() + self.provider_manager = ProviderManager() # Managers that interface with parallelism self.per_agent_lock_manager = PerAgentLockManager() @@ -1030,7 +1032,7 @@ class SyncServer(Server): """List available models""" llm_models = [] - for provider in self._enabled_providers: + for provider in self.get_enabled_providers(): try: llm_models.extend(provider.list_llm_models()) except Exception as e: @@ -1040,13 +1042,19 @@ class SyncServer(Server): def list_embedding_models(self) -> List[EmbeddingConfig]: """List available embedding models""" embedding_models = [] - for provider in self._enabled_providers: + for provider in self.get_enabled_providers(): try: embedding_models.extend(provider.list_embedding_models()) except Exception as e: warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") return embedding_models + def get_enabled_providers(self): + providers_from_env = {p.name: p for p in self._enabled_providers} + providers_from_db = {p.name: p for p in self.provider_manager.list_providers()} + # Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur + return {**providers_from_env, **providers_from_db}.values() + def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig: provider_name, model_name = handle.split("/", 1) provider = self.get_provider_from_name(provider_name) diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py new file mode 100644 index 00000000..8fe9797b --- /dev/null +++ b/letta/services/provider_manager.py @@ -0,0 +1,63 @@ +from typing import List, Optional + +from letta.orm.provider import Provider as ProviderModel +from letta.providers import Provider as PydanticProvider +from letta.providers import ProviderUpdate +from letta.utils import enforce_types + + +class ProviderManager: + + def __init__(self): + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def create_provider(self, provider: PydanticProvider) -> PydanticProvider: + """Create a new provider if it doesn't already exist.""" + with self.session_maker() as session: + new_provider = ProviderModel(**provider.model_dump()) + new_provider.create(session) + return new_provider.to_pydantic() + + @enforce_types + def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider: + """Update provider details.""" + with self.session_maker() as session: + # Retrieve the existing provider by ID + existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id) + + # Update only the fields that are provided in ProviderUpdate + update_data = provider_update.model_dump(exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(existing_provider, key, value) + + # Commit the updated provider + existing_provider.update(session) + return existing_provider.to_pydantic() + + @enforce_types + def delete_provider_by_id(self, provider_id: str): + """Delete a provider.""" + with self.session_maker() as session: + # Delete from provider table + provider = ProviderModel.read(db_session=session, identifier=provider_id) + provider.hard_delete(session) + + session.commit() + + @enforce_types + def list_providers(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]: + """List providers with pagination using cursor (id) and limit.""" + with self.session_maker() as session: + results = ProviderModel.list(db_session=session, cursor=cursor, limit=limit) + return [provider.to_pydantic() for provider in results] + + @enforce_types + def get_anthropic_key_override(self) -> Optional[str]: + """Helper function to fetch custom anthropic key for v0 BYOK feature""" + providers = self.list_providers(limit=1) + if len(providers) == 1 and providers[0].name == "anthropic": + return providers[0].api_key + return None