From 9f7c533765b73a6237964e3f3717d02a18b6debe Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 12 Aug 2025 15:50:00 -0700 Subject: [PATCH] feat: add azure byok (#3884) --- ...241fc_add_api_version_to_byok_providers.py | 31 ++++++++++++++++ letta/llm_api/azure_client.py | 37 +++++++++++++++---- letta/orm/provider.py | 1 + letta/schemas/providers/azure.py | 9 +++++ letta/schemas/providers/base.py | 7 ++++ letta/services/provider_manager.py | 24 ++++++++++++ 6 files changed, 101 insertions(+), 8 deletions(-) create mode 100644 alembic/versions/ffb17eb241fc_add_api_version_to_byok_providers.py diff --git a/alembic/versions/ffb17eb241fc_add_api_version_to_byok_providers.py b/alembic/versions/ffb17eb241fc_add_api_version_to_byok_providers.py new file mode 100644 index 00000000..28c2a288 --- /dev/null +++ b/alembic/versions/ffb17eb241fc_add_api_version_to_byok_providers.py @@ -0,0 +1,31 @@ +"""add api version to byok providers + +Revision ID: ffb17eb241fc +Revises: 5fb8bba2c373 +Create Date: 2025-08-12 14:35:26.375985 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "ffb17eb241fc" +down_revision: Union[str, None] = "5fb8bba2c373" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("providers", sa.Column("api_version", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("providers", "api_version") + # ### end Alembic commands ### diff --git a/letta/llm_api/azure_client.py b/letta/llm_api/azure_client.py index 95468896..423be531 100644 --- a/letta/llm_api/azure_client.py +++ b/letta/llm_api/azure_client.py @@ -7,22 +7,41 @@ from openai.types.chat.chat_completion import ChatCompletion from letta.llm_api.openai_client import OpenAIClient from letta.otel.tracing import trace_method from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.settings import model_settings class AzureClient(OpenAIClient): + def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]: + if llm_config.provider_category == ProviderCategory.byok: + from letta.services.provider_manager import ProviderManager + + return ProviderManager().get_azure_credentials(llm_config.provider_name, actor=self.actor) + + return None, None, None + + async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]: + if llm_config.provider_category == ProviderCategory.byok: + from letta.services.provider_manager import ProviderManager + + return await ProviderManager().get_azure_credentials_async(llm_config.provider_name, actor=self.actor) + + return None, None, None + @trace_method def request(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying synchronous request to OpenAI API and returns raw response dict. """ - api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY") - base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL") - api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION") - client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) + api_key, base_url, api_version = self.get_byok_overrides(llm_config) + if not api_key or not base_url or not api_version: + api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY") + base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL") + api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION") + client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) response: ChatCompletion = client.chat.completions.create(**request_data) return response.model_dump() @@ -31,11 +50,13 @@ class AzureClient(OpenAIClient): """ Performs underlying asynchronous request to OpenAI API and returns raw response dict. """ - api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY") - base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL") - api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION") - client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) + api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config) + if not api_key or not base_url or not api_version: + api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY") + base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL") + api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION") + client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) response: ChatCompletion = await client.chat.completions.create(**request_data) return response.model_dump() diff --git a/letta/orm/provider.py b/letta/orm/provider.py index 1237ea9d..b46a95b8 100644 --- a/letta/orm/provider.py +++ b/letta/orm/provider.py @@ -31,6 +31,7 @@ class Provider(SqlalchemyBase, OrganizationMixin): base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.") access_key: Mapped[str] = mapped_column(nullable=True, doc="Access key used for requests to the provider.") region: Mapped[str] = mapped_column(nullable=True, doc="Region used for requests to the provider.") + api_version: Mapped[str] = mapped_column(nullable=True, doc="API version used for requests to the provider.") # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="providers") diff --git a/letta/schemas/providers/azure.py b/letta/schemas/providers/azure.py index e51c1775..21b9e1d4 100644 --- a/letta/schemas/providers/azure.py +++ b/letta/schemas/providers/azure.py @@ -78,3 +78,12 @@ class AzureProvider(Provider): # Hard coded as there are no API endpoints for this llm_default = LLM_MAX_TOKENS.get(model_name, 4096) return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default) + + async def check_api_key(self): + if not self.api_key: + raise ValueError("No API key provided") + + try: + await self.list_llm_models_async() + except Exception as e: + raise LLMAuthenticationError(message=f"Failed to authenticate with Azure: {e}", code=ErrorCode.UNAUTHENTICATED) diff --git a/letta/schemas/providers/base.py b/letta/schemas/providers/base.py index eef2cb39..2ecbe2b3 100644 --- a/letta/schemas/providers/base.py +++ b/letta/schemas/providers/base.py @@ -24,6 +24,7 @@ class Provider(ProviderBase): base_url: str | None = Field(None, description="Base URL for the provider.") access_key: str | None = Field(None, description="Access key used for requests to the provider.") region: str | None = Field(None, description="Region used for requests to the provider.") + api_version: str | None = Field(None, description="API version used for requests to the provider.") organization_id: str | None = Field(None, description="The organization id of the user") updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.") @@ -186,12 +187,16 @@ class ProviderCreate(ProviderBase): api_key: str = Field(..., description="API key or secret key used for requests to the provider.") access_key: str | None = Field(None, description="Access key used for requests to the provider.") region: str | None = Field(None, description="Region used for requests to the provider.") + base_url: str | None = Field(None, description="Base URL used for requests to the provider.") + api_version: str | None = Field(None, description="API version used for requests to the provider.") class ProviderUpdate(ProviderBase): api_key: str = Field(..., description="API key or secret key used for requests to the provider.") access_key: str | None = Field(None, description="Access key used for requests to the provider.") region: str | None = Field(None, description="Region used for requests to the provider.") + base_url: str | None = Field(None, description="Base URL used for requests to the provider.") + api_version: str | None = Field(None, description="API version used for requests to the provider.") class ProviderCheck(BaseModel): @@ -199,3 +204,5 @@ class ProviderCheck(BaseModel): api_key: str = Field(..., description="API key or secret key used for requests to the provider.") access_key: str | None = Field(None, description="Access key used for requests to the provider.") region: str | None = Field(None, description="Region used for requests to the provider.") + base_url: str | None = Field(None, description="Base URL used for requests to the provider.") + api_version: str | None = Field(None, description="API version used for requests to the provider.") diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 46fe615e..cfb32a82 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -205,6 +205,28 @@ class ProviderManager: region = providers[0].region if providers else None return access_key, secret_key, region + @enforce_types + @trace_method + def get_azure_credentials( + self, provider_name: Union[str, None], actor: PydanticUser + ) -> Tuple[Optional[str], Optional[str], Optional[str]]: + providers = self.list_providers(name=provider_name, actor=actor) + api_key = providers[0].api_key if providers else None + base_url = providers[0].base_url if providers else None + api_version = providers[0].api_version if providers else None + return api_key, base_url, api_version + + @enforce_types + @trace_method + async def get_azure_credentials_async( + self, provider_name: Union[str, None], actor: PydanticUser + ) -> Tuple[Optional[str], Optional[str], Optional[str]]: + providers = await self.list_providers_async(name=provider_name, actor=actor) + api_key = providers[0].api_key if providers else None + base_url = providers[0].base_url if providers else None + api_version = providers[0].api_version if providers else None + return api_key, base_url, api_version + @enforce_types @trace_method async def check_provider_api_key(self, provider_check: ProviderCheck) -> None: @@ -215,6 +237,8 @@ class ProviderManager: provider_category=ProviderCategory.byok, access_key=provider_check.access_key, # This contains the access key ID for Bedrock region=provider_check.region, + base_url=provider_check.base_url, + api_version=provider_check.api_version, ).cast_to_subtype() # TODO: add more string sanity checks here before we hit actual endpoints