feat: add azure byok (#3884)
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
"""add api version to byok providers
|
||||
|
||||
Revision ID: ffb17eb241fc
|
||||
Revises: 5fb8bba2c373
|
||||
Create Date: 2025-08-12 14:35:26.375985
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ffb17eb241fc"
|
||||
down_revision: Union[str, None] = "5fb8bba2c373"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add a nullable string column ``api_version`` to the ``providers`` table."""
    api_version_column = sa.Column("api_version", sa.String(), nullable=True)
    op.add_column("providers", api_version_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``api_version`` column from the ``providers`` table."""
    op.drop_column(table_name="providers", column_name="api_version")
|
||||
@@ -7,22 +7,41 @@ from openai.types.chat.chat_completion import ChatCompletion
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class AzureClient(OpenAIClient):
|
||||
|
||||
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """
    Look up Azure credentials for a BYOK model configuration.

    Returns:
        The 3-tuple produced by ``ProviderManager.get_azure_credentials`` for
        ``llm_config.provider_name`` when the config's provider category is
        BYOK; ``(None, None, None)`` for every other category.
    """
    is_byok = llm_config.provider_category == ProviderCategory.byok
    if not is_byok:
        return None, None, None

    # Kept as a function-scope import, as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from letta.services.provider_manager import ProviderManager

    provider_manager = ProviderManager()
    return provider_manager.get_azure_credentials(llm_config.provider_name, actor=self.actor)
|
||||
|
||||
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """
    Asynchronously look up Azure credentials for a BYOK model configuration.

    Returns:
        The awaited 3-tuple from ``ProviderManager.get_azure_credentials_async``
        for ``llm_config.provider_name`` when the config's provider category is
        BYOK; ``(None, None, None)`` for every other category.
    """
    is_byok = llm_config.provider_category == ProviderCategory.byok
    if not is_byok:
        return None, None, None

    # Kept as a function-scope import, as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from letta.services.provider_manager import ProviderManager

    provider_manager = ProviderManager()
    credentials = await provider_manager.get_azure_credentials_async(llm_config.provider_name, actor=self.actor)
    return credentials
|
||||
|
||||
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """
    Perform the underlying synchronous chat-completions request against
    Azure OpenAI and return the raw response as a dict.

    Credentials are resolved exactly once, before any client is built:

    1. BYOK overrides from ``self.get_byok_overrides(llm_config)``.
    2. If any of api key / base URL / API version is missing from the
       overrides, all three are replaced by the global values from
       ``model_settings``, falling back to the ``AZURE_API_KEY``,
       ``AZURE_BASE_URL`` and ``AZURE_API_VERSION`` environment variables.

    Args:
        request_data: Keyword arguments forwarded to
            ``client.chat.completions.create``.
        llm_config: Model configuration used to look up BYOK credentials.

    Returns:
        ``response.model_dump()`` of the chat completion.
    """
    api_key, base_url, api_version = self.get_byok_overrides(llm_config)
    if not (api_key and base_url and api_version):
        # Incomplete BYOK triple: fall back wholesale to the global settings
        # rather than mixing BYOK and global values.
        api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
        base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
        api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")

    # Build the client only after credentials are final, so missing global
    # settings cannot break a request that has valid BYOK credentials.
    client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
    response: ChatCompletion = client.chat.completions.create(**request_data)
    return response.model_dump()
|
||||
|
||||
@@ -31,11 +50,13 @@ class AzureClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
|
||||
if not api_key or not base_url or not api_version:
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
|
||||
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ class Provider(SqlalchemyBase, OrganizationMixin):
|
||||
base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
|
||||
access_key: Mapped[str] = mapped_column(nullable=True, doc="Access key used for requests to the provider.")
|
||||
region: Mapped[str] = mapped_column(nullable=True, doc="Region used for requests to the provider.")
|
||||
api_version: Mapped[str] = mapped_column(nullable=True, doc="API version used for requests to the provider.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
|
||||
|
||||
@@ -78,3 +78,12 @@ class AzureProvider(Provider):
|
||||
# Hard coded as there are no API endpoints for this
|
||||
llm_default = LLM_MAX_TOKENS.get(model_name, 4096)
|
||||
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)
|
||||
|
||||
async def check_api_key(self):
    """
    Validate the provider's API key by attempting to list its LLM models.

    Raises:
        ValueError: If ``self.api_key`` is empty or unset.
        LLMAuthenticationError: If ``list_llm_models_async`` raises for any
            reason; the original exception is attached as ``__cause__``.
    """
    if not self.api_key:
        raise ValueError("No API key provided")

    try:
        await self.list_llm_models_async()
    except Exception as e:
        # Every failure of the probe call is reported as an authentication
        # error; chain explicitly so the root cause stays in the traceback.
        raise LLMAuthenticationError(message=f"Failed to authenticate with Azure: {e}", code=ErrorCode.UNAUTHENTICATED) from e
|
||||
|
||||
@@ -24,6 +24,7 @@ class Provider(ProviderBase):
|
||||
base_url: str | None = Field(None, description="Base URL for the provider.")
|
||||
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
|
||||
region: str | None = Field(None, description="Region used for requests to the provider.")
|
||||
api_version: str | None = Field(None, description="API version used for requests to the provider.")
|
||||
organization_id: str | None = Field(None, description="The organization id of the user")
|
||||
updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
|
||||
|
||||
@@ -186,12 +187,16 @@ class ProviderCreate(ProviderBase):
|
||||
api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
|
||||
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
|
||||
region: str | None = Field(None, description="Region used for requests to the provider.")
|
||||
base_url: str | None = Field(None, description="Base URL used for requests to the provider.")
|
||||
api_version: str | None = Field(None, description="API version used for requests to the provider.")
|
||||
|
||||
|
||||
class ProviderUpdate(ProviderBase):
    # Required credential.
    api_key: str = Field(default=..., description="API key or secret key used for requests to the provider.")
    # Optional connection settings; each defaults to None when omitted.
    access_key: str | None = Field(default=None, description="Access key used for requests to the provider.")
    region: str | None = Field(default=None, description="Region used for requests to the provider.")
    base_url: str | None = Field(default=None, description="Base URL used for requests to the provider.")
    api_version: str | None = Field(default=None, description="API version used for requests to the provider.")
|
||||
|
||||
|
||||
class ProviderCheck(BaseModel):
|
||||
@@ -199,3 +204,5 @@ class ProviderCheck(BaseModel):
|
||||
api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
|
||||
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
|
||||
region: str | None = Field(None, description="Region used for requests to the provider.")
|
||||
base_url: str | None = Field(None, description="Base URL used for requests to the provider.")
|
||||
api_version: str | None = Field(None, description="API version used for requests to the provider.")
|
||||
|
||||
@@ -205,6 +205,28 @@ class ProviderManager:
|
||||
region = providers[0].region if providers else None
|
||||
return access_key, secret_key, region
|
||||
|
||||
@enforce_types
@trace_method
def get_azure_credentials(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """
    Return ``(api_key, base_url, api_version)`` taken from the first provider
    that ``list_providers`` yields for ``provider_name`` and ``actor``.

    Returns ``(None, None, None)`` when no provider is found.
    """
    matches = self.list_providers(name=provider_name, actor=actor)
    if not matches:
        return None, None, None

    # Only the first match is consulted; any further matches are ignored.
    first = matches[0]
    return first.api_key, first.base_url, first.api_version
|
||||
|
||||
@enforce_types
@trace_method
async def get_azure_credentials_async(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """
    Return ``(api_key, base_url, api_version)`` taken from the first provider
    that ``list_providers_async`` yields for ``provider_name`` and ``actor``.

    Returns ``(None, None, None)`` when no provider is found.
    """
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None, None, None

    # Only the first match is consulted; any further matches are ignored.
    first = matches[0]
    return first.api_key, first.base_url, first.api_version
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
|
||||
@@ -215,6 +237,8 @@ class ProviderManager:
|
||||
provider_category=ProviderCategory.byok,
|
||||
access_key=provider_check.access_key, # This contains the access key ID for Bedrock
|
||||
region=provider_check.region,
|
||||
base_url=provider_check.base_url,
|
||||
api_version=provider_check.api_version,
|
||||
).cast_to_subtype()
|
||||
|
||||
# TODO: add more string sanity checks here before we hit actual endpoints
|
||||
|
||||
Reference in New Issue
Block a user