feat: add azure byok (#3884)

This commit is contained in:
cthomas
2025-08-12 15:50:00 -07:00
committed by GitHub
parent e2e91c7260
commit 78cfb4902d
6 changed files with 101 additions and 8 deletions

View File

@@ -0,0 +1,31 @@
"""add api version to byok providers
Revision ID: ffb17eb241fc
Revises: 5fb8bba2c373
Create Date: 2025-08-12 14:35:26.375985
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "ffb17eb241fc"
down_revision: Union[str, None] = "5fb8bba2c373"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("providers", sa.Column("api_version", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("providers", "api_version")
# ### end Alembic commands ###

View File

@@ -7,22 +7,41 @@ from openai.types.chat.chat_completion import ChatCompletion
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings
class AzureClient(OpenAIClient):
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
return ProviderManager().get_azure_credentials(llm_config.provider_name, actor=self.actor)
return None, None, None
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
return await ProviderManager().get_azure_credentials_async(llm_config.provider_name, actor=self.actor)
return None, None, None
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying synchronous request to OpenAI API and returns raw response dict.
"""
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
api_key, base_url, api_version = self.get_byok_overrides(llm_config)
if not api_key or not base_url or not api_version:
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
@@ -31,11 +50,13 @@ class AzureClient(OpenAIClient):
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
if not api_key or not base_url or not api_version:
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()

View File

@@ -31,6 +31,7 @@ class Provider(SqlalchemyBase, OrganizationMixin):
base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
access_key: Mapped[str] = mapped_column(nullable=True, doc="Access key used for requests to the provider.")
region: Mapped[str] = mapped_column(nullable=True, doc="Region used for requests to the provider.")
api_version: Mapped[str] = mapped_column(nullable=True, doc="API version used for requests to the provider.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")

View File

@@ -78,3 +78,12 @@ class AzureProvider(Provider):
# Hard coded as there are no API endpoints for this
llm_default = LLM_MAX_TOKENS.get(model_name, 4096)
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)
async def check_api_key(self):
if not self.api_key:
raise ValueError("No API key provided")
try:
await self.list_llm_models_async()
except Exception as e:
raise LLMAuthenticationError(message=f"Failed to authenticate with Azure: {e}", code=ErrorCode.UNAUTHENTICATED)

View File

@@ -24,6 +24,7 @@ class Provider(ProviderBase):
base_url: str | None = Field(None, description="Base URL for the provider.")
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
region: str | None = Field(None, description="Region used for requests to the provider.")
api_version: str | None = Field(None, description="API version used for requests to the provider.")
organization_id: str | None = Field(None, description="The organization id of the user")
updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
@@ -186,12 +187,16 @@ class ProviderCreate(ProviderBase):
api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
region: str | None = Field(None, description="Region used for requests to the provider.")
base_url: str | None = Field(None, description="Base URL used for requests to the provider.")
api_version: str | None = Field(None, description="API version used for requests to the provider.")
class ProviderUpdate(ProviderBase):
api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
region: str | None = Field(None, description="Region used for requests to the provider.")
base_url: str | None = Field(None, description="Base URL used for requests to the provider.")
api_version: str | None = Field(None, description="API version used for requests to the provider.")
class ProviderCheck(BaseModel):
@@ -199,3 +204,5 @@ class ProviderCheck(BaseModel):
api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
region: str | None = Field(None, description="Region used for requests to the provider.")
base_url: str | None = Field(None, description="Base URL used for requests to the provider.")
api_version: str | None = Field(None, description="API version used for requests to the provider.")

View File

@@ -205,6 +205,28 @@ class ProviderManager:
region = providers[0].region if providers else None
return access_key, secret_key, region
@enforce_types
@trace_method
def get_azure_credentials(
self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
providers = self.list_providers(name=provider_name, actor=actor)
api_key = providers[0].api_key if providers else None
base_url = providers[0].base_url if providers else None
api_version = providers[0].api_version if providers else None
return api_key, base_url, api_version
@enforce_types
@trace_method
async def get_azure_credentials_async(
self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
providers = await self.list_providers_async(name=provider_name, actor=actor)
api_key = providers[0].api_key if providers else None
base_url = providers[0].base_url if providers else None
api_version = providers[0].api_version if providers else None
return api_key, base_url, api_version
@enforce_types
@trace_method
async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
@@ -215,6 +237,8 @@ class ProviderManager:
provider_category=ProviderCategory.byok,
access_key=provider_check.access_key, # This contains the access key ID for Bedrock
region=provider_check.region,
base_url=provider_check.base_url,
api_version=provider_check.api_version,
).cast_to_subtype()
# TODO: add more string sanity checks here before we hit actual endpoints