feat: byok provider models in db also (#8317)
* feat: byok provider models in db also
* make tests and sync api
* fix inconsistent state when recreating a provider of the same name
* fix sync on byok creation
* update revision
* move stripe code for testing purposes
* revert
* add refresh byok models endpoint
* just stage publish api
* add tests
* reorder revision
* add test for name clashes
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
"""last_synced column for providers
|
||||
|
||||
Revision ID: 308a180244fc
|
||||
Revises: 82feb220a9b8
|
||||
Create Date: 2026-01-05 18:54:15.996786
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
# This migration adds the `last_synced` column; it chains after 82feb220a9b8.
revision: str = "308a180244fc"
down_revision: Union[str, None] = "82feb220a9b8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``last_synced`` timestamp column to ``providers``.

    ``timezone=True`` stores an aware timestamp; the column is nullable so
    existing rows (and providers that have never synced) read as NULL,
    meaning "models not yet synced".
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("providers", sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True))
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the migration by dropping ``providers.last_synced``.

    Any sync timestamps stored in the column are lost on downgrade.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("providers", "last_synced")
    # ### end Alembic commands ###
|
||||
@@ -14826,6 +14826,53 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/providers/{provider_id}/refresh": {
|
||||
"patch": {
|
||||
"tags": ["providers"],
|
||||
"summary": "Refresh Provider Models",
|
||||
"description": "Refresh models for a BYOK provider by querying the provider's API.\nAdds new models and removes ones no longer available.",
|
||||
"operationId": "refresh_provider_models",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "provider_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"minLength": 45,
|
||||
"maxLength": 45,
|
||||
"pattern": "^provider-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
|
||||
"description": "The ID of the provider in the format 'provider-<uuid4>'",
|
||||
"examples": ["provider-123e4567-e89b-42d3-8456-426614174000"],
|
||||
"title": "Provider Id"
|
||||
},
|
||||
"description": "The ID of the provider in the format 'provider-<uuid4>'"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Provider"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/runs/": {
|
||||
"get": {
|
||||
"tags": ["runs"],
|
||||
@@ -38973,6 +39020,19 @@
|
||||
"title": "Updated At",
|
||||
"description": "The last update timestamp of the provider."
|
||||
},
|
||||
"last_synced": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Last Synced",
|
||||
"description": "The last time models were synced for this provider."
|
||||
},
|
||||
"api_key_enc": {
|
||||
"anyOf": [
|
||||
{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, String, Text, UniqueConstraint
|
||||
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
@@ -41,6 +42,11 @@ class Provider(SqlalchemyBase, OrganizationMixin):
|
||||
api_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted API key or secret key for the provider.")
|
||||
access_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted access key for the provider.")
|
||||
|
||||
# sync tracking
|
||||
last_synced: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, doc="Last time models were synced for this provider."
|
||||
)
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
|
||||
models: Mapped[list["ProviderModel"]] = relationship("ProviderModel", back_populates="provider", cascade="all, delete-orphan")
|
||||
|
||||
@@ -32,6 +32,7 @@ class Provider(ProviderBase):
|
||||
api_version: str | None = Field(None, description="API version used for requests to the provider.")
|
||||
organization_id: str | None = Field(None, description="The organization id of the user")
|
||||
updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
|
||||
last_synced: datetime | None = Field(None, description="The last time models were synced for this provider.")
|
||||
|
||||
# Encrypted fields (stored as Secret objects, serialized to strings for DB)
|
||||
# Secret class handles validation and serialization automatically via __get_pydantic_core_schema__
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query, status
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
@@ -144,6 +144,27 @@ async def check_existing_provider(
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{provider_id}/refresh", response_model=Provider, operation_id="refresh_provider_models")
async def refresh_provider_models(
    provider_id: ProviderId,
    headers: HeaderParams = Depends(get_headers),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Refresh models for a BYOK provider by querying the provider's API.
    Adds new models and removes ones no longer available.
    """
    # Resolve the acting user from request headers (falls back to default actor).
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)

    # Only allow refresh for BYOK providers
    if provider.provider_category != ProviderCategory.byok:
        raise HTTPException(status_code=400, detail="Refresh is only supported for BYOK providers")

    # Fetch the provider's current model list and persist it to the database.
    # NOTE(review): this calls a private manager method — presumably intentional; confirm.
    await server.provider_manager._sync_default_models_for_provider(provider, actor)
    # Re-read so the response reflects state written during the sync
    # (e.g. an updated last_synced timestamp).
    return await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
|
||||
|
||||
|
||||
@router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
|
||||
async def delete_provider(
|
||||
provider_id: ProviderId,
|
||||
|
||||
@@ -466,6 +466,8 @@ class SyncServer(object):
|
||||
embedding_models=embedding_models,
|
||||
organization_id=None, # Global models
|
||||
)
|
||||
# Update last_synced timestamp
|
||||
await self.provider_manager.update_provider_last_synced_async(persisted_provider.id)
|
||||
logger.info(
|
||||
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
|
||||
)
|
||||
@@ -1177,7 +1179,7 @@ class SyncServer(object):
|
||||
llm_config = LLMConfig(
|
||||
model=model.name,
|
||||
model_endpoint_type=model.model_endpoint_type,
|
||||
model_endpoint=provider.base_url or model.model_endpoint_type,
|
||||
model_endpoint=provider.base_url,
|
||||
context_window=model.max_context_window or 16384,
|
||||
handle=model.handle,
|
||||
provider_name=provider.name,
|
||||
@@ -1185,7 +1187,7 @@ class SyncServer(object):
|
||||
)
|
||||
llm_models.append(llm_config)
|
||||
|
||||
# Get BYOK provider models by hitting provider endpoints directly
|
||||
# Get BYOK provider models - sync if not synced yet, then read from DB
|
||||
if include_byok:
|
||||
byok_providers = await self.provider_manager.list_providers_async(
|
||||
actor=actor,
|
||||
@@ -1196,9 +1198,37 @@ class SyncServer(object):
|
||||
|
||||
for provider in byok_providers:
|
||||
try:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
models = await typed_provider.list_llm_models_async()
|
||||
llm_models.extend(models)
|
||||
# Sync models if not synced yet
|
||||
if provider.last_synced is None:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
models = await typed_provider.list_llm_models_async()
|
||||
embedding_models = await typed_provider.list_embedding_models_async()
|
||||
await self.provider_manager.sync_provider_models_async(
|
||||
provider=provider,
|
||||
llm_models=models,
|
||||
embedding_models=embedding_models,
|
||||
organization_id=provider.organization_id,
|
||||
)
|
||||
await self.provider_manager.update_provider_last_synced_async(provider.id)
|
||||
|
||||
# Read from database
|
||||
provider_llm_models = await self.provider_manager.list_models_async(
|
||||
actor=actor,
|
||||
model_type="llm",
|
||||
provider_id=provider.id,
|
||||
enabled=True,
|
||||
)
|
||||
for model in provider_llm_models:
|
||||
llm_config = LLMConfig(
|
||||
model=model.name,
|
||||
model_endpoint_type=model.model_endpoint_type,
|
||||
model_endpoint=provider.base_url,
|
||||
context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
|
||||
handle=model.handle,
|
||||
provider_name=provider.name,
|
||||
provider_category=ProviderCategory.byok,
|
||||
)
|
||||
llm_models.append(llm_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
|
||||
|
||||
@@ -1240,7 +1270,7 @@ class SyncServer(object):
|
||||
)
|
||||
embedding_models.append(embedding_config)
|
||||
|
||||
# Get BYOK provider models by hitting provider endpoints directly
|
||||
# Get BYOK provider models - sync if not synced yet, then read from DB
|
||||
byok_providers = await self.provider_manager.list_providers_async(
|
||||
actor=actor,
|
||||
provider_category=[ProviderCategory.byok],
|
||||
@@ -1248,9 +1278,36 @@ class SyncServer(object):
|
||||
|
||||
for provider in byok_providers:
|
||||
try:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
models = await typed_provider.list_embedding_models_async()
|
||||
embedding_models.extend(models)
|
||||
# Sync models if not synced yet
|
||||
if provider.last_synced is None:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
llm_models = await typed_provider.list_llm_models_async()
|
||||
emb_models = await typed_provider.list_embedding_models_async()
|
||||
await self.provider_manager.sync_provider_models_async(
|
||||
provider=provider,
|
||||
llm_models=llm_models,
|
||||
embedding_models=emb_models,
|
||||
organization_id=provider.organization_id,
|
||||
)
|
||||
await self.provider_manager.update_provider_last_synced_async(provider.id)
|
||||
|
||||
# Read from database
|
||||
provider_embedding_models = await self.provider_manager.list_models_async(
|
||||
actor=actor,
|
||||
model_type="embedding",
|
||||
provider_id=provider.id,
|
||||
enabled=True,
|
||||
)
|
||||
for model in provider_embedding_models:
|
||||
embedding_config = EmbeddingConfig(
|
||||
embedding_model=model.name,
|
||||
embedding_endpoint_type=model.model_endpoint_type,
|
||||
embedding_endpoint=provider.base_url or model.model_endpoint_type,
|
||||
embedding_dim=model.embedding_dim or 1536,
|
||||
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
handle=model.handle,
|
||||
)
|
||||
embedding_models.append(embedding_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")
|
||||
|
||||
|
||||
@@ -98,9 +98,30 @@ class ProviderManager:
|
||||
deleted_provider.access_key_enc = access_key_secret.get_encrypted()
|
||||
|
||||
await deleted_provider.update_async(session, actor=actor)
|
||||
|
||||
# Also restore any soft-deleted models associated with this provider
|
||||
# This is needed because the unique constraint on provider_models doesn't include is_deleted,
|
||||
# so soft-deleted models would block creation of new models with the same handle
|
||||
from sqlalchemy import update
|
||||
|
||||
restore_models_stmt = (
|
||||
update(ProviderModelORM)
|
||||
.where(
|
||||
and_(
|
||||
ProviderModelORM.provider_id == deleted_provider.id,
|
||||
ProviderModelORM.is_deleted == True,
|
||||
)
|
||||
)
|
||||
.values(is_deleted=False)
|
||||
)
|
||||
result = await session.execute(restore_models_stmt)
|
||||
if result.rowcount > 0:
|
||||
logger.info(f"Restored {result.rowcount} soft-deleted model(s) for provider '{request.name}'")
|
||||
|
||||
provider_pydantic = deleted_provider.to_pydantic()
|
||||
|
||||
# For BYOK providers, automatically sync available models
|
||||
# This will add any new models and remove any that are no longer available
|
||||
if is_byok:
|
||||
await self._sync_default_models_for_provider(provider_pydantic, actor)
|
||||
|
||||
@@ -201,6 +222,17 @@ class ProviderManager:
|
||||
await existing_provider.update_async(session, actor=actor)
|
||||
return existing_provider.to_pydantic()
|
||||
|
||||
    @enforce_types
    @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
    async def update_provider_last_synced_async(self, provider_id: str) -> None:
        """Update the last_synced timestamp for a provider.

        Sets ``last_synced`` to the current UTC time and commits.

        Args:
            provider_id: The provider's ID (format ``provider-<uuid4>``,
                enforced by the decorator).
        """
        from datetime import datetime, timezone

        async with db_registry.async_session() as session:
            # actor=None: presumably the read is not org-scoped here — TODO confirm
            # this is safe for multi-tenant callers.
            provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=None)
            provider.last_synced = datetime.now(timezone.utc)
            await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
@trace_method
|
||||
@@ -476,81 +508,19 @@ class ProviderManager:
|
||||
|
||||
async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
|
||||
"""Sync models for a newly created BYOK provider by querying the provider's API."""
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
try:
|
||||
# Get the provider class and create an instance
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.providers.anthropic import AnthropicProvider
|
||||
from letta.schemas.providers.azure import AzureProvider
|
||||
from letta.schemas.providers.bedrock import BedrockProvider
|
||||
from letta.schemas.providers.google_gemini import GoogleAIProvider
|
||||
from letta.schemas.providers.groq import GroqProvider
|
||||
from letta.schemas.providers.ollama import OllamaProvider
|
||||
from letta.schemas.providers.openai import OpenAIProvider
|
||||
from letta.schemas.providers.zai import ZAIProvider
|
||||
# Use cast_to_subtype() which properly handles all provider types and preserves api_key_enc
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
llm_models = await typed_provider.list_llm_models_async()
|
||||
embedding_models = await typed_provider.list_embedding_models_async()
|
||||
|
||||
# ChatGPT OAuth requires cast_to_subtype to preserve api_key_enc and id
|
||||
# (needed for OAuth token refresh and database persistence)
|
||||
if provider.provider_type == ProviderType.chatgpt_oauth:
|
||||
provider_instance = provider.cast_to_subtype()
|
||||
else:
|
||||
provider_type_to_class = {
|
||||
"openai": OpenAIProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
"groq": GroqProvider,
|
||||
"google": GoogleAIProvider,
|
||||
"ollama": OllamaProvider,
|
||||
"bedrock": BedrockProvider,
|
||||
"azure": AzureProvider,
|
||||
"zai": ZAIProvider,
|
||||
}
|
||||
|
||||
provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)
|
||||
provider_class = provider_type_to_class.get(provider_type)
|
||||
|
||||
if not provider_class:
|
||||
logger.warning(f"No provider class found for type '{provider_type}'")
|
||||
return
|
||||
|
||||
# Create provider instance with necessary parameters
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
kwargs = {
|
||||
"name": provider.name,
|
||||
"api_key": api_key,
|
||||
"provider_category": provider.provider_category,
|
||||
}
|
||||
if provider.base_url:
|
||||
kwargs["base_url"] = provider.base_url
|
||||
if access_key:
|
||||
kwargs["access_key"] = access_key
|
||||
if provider.region:
|
||||
kwargs["region"] = provider.region
|
||||
if provider.api_version:
|
||||
kwargs["api_version"] = provider.api_version
|
||||
|
||||
provider_instance = provider_class(**kwargs)
|
||||
|
||||
# Query the provider's API for available models
|
||||
llm_models = await provider_instance.list_llm_models_async()
|
||||
embedding_models = await provider_instance.list_embedding_models_async()
|
||||
|
||||
# Update handles and provider_name for BYOK providers
|
||||
for model in llm_models:
|
||||
model.provider_name = provider.name
|
||||
model.handle = f"{provider.name}/{model.model}"
|
||||
model.provider_category = provider.provider_category
|
||||
|
||||
for model in embedding_models:
|
||||
model.handle = f"{provider.name}/{model.embedding_model}"
|
||||
|
||||
# Use existing sync_provider_models_async to save to database
|
||||
await self.sync_provider_models_async(
|
||||
provider=provider, llm_models=llm_models, embedding_models=embedding_models, organization_id=actor.organization_id
|
||||
provider=provider,
|
||||
llm_models=llm_models,
|
||||
embedding_models=embedding_models,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
await self.update_provider_last_synced_async(provider.id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync models for provider '{provider.name}': {e}")
|
||||
|
||||
@@ -2292,6 +2292,7 @@ async def test_server_list_llm_models_byok_from_provider_api(default_user, provi
|
||||
# Create a mock typed provider that returns our test models
|
||||
mock_typed_provider = MagicMock()
|
||||
mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models)
|
||||
mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[])
|
||||
|
||||
# Patch cast_to_subtype on the Provider class to return our mock
|
||||
with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
|
||||
@@ -2525,3 +2526,620 @@ async def test_create_agent_with_byok_handle_dynamic_fetch(default_user, provide
|
||||
|
||||
# Cleanup
|
||||
await server.agent_manager.delete_agent_async(agent_id=agent.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_byok_provider_last_synced_triggers_sync_when_null(default_user, provider_manager):
    """Test that BYOK providers with last_synced=null trigger a sync on first model listing."""
    from letta.schemas.providers import Provider
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create a BYOK provider (last_synced will be null by default)
    # NOTE(review): the creation-time sync presumably fails with this fake key and is
    # swallowed by the manager's error handling, leaving last_synced unset — confirm.
    byok_provider_create = ProviderCreate(
        name=f"test-byok-sync-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-key",
    )
    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

    # Verify last_synced is null initially
    assert byok_provider.last_synced is None

    # Create server
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    # Mock the BYOK provider's list_llm_models_async to return test models
    mock_byok_models = [
        LLMConfig(
            model=f"byok-gpt-4o-{test_id}",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=64000,
            handle=f"test-byok-sync-{test_id}/gpt-4o",
            provider_name=byok_provider.name,
            provider_category=ProviderCategory.byok,
        )
    ]

    mock_typed_provider = MagicMock()
    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models)
    mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[])

    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
        # List BYOK models - should trigger sync because last_synced is null
        byok_models = await server.list_llm_models_async(
            actor=default_user,
            provider_category=[ProviderCategory.byok],
        )

    # Verify sync was triggered (cast_to_subtype was called to fetch from API)
    # Note: may be called multiple times if other BYOK providers exist in DB
    mock_typed_provider.list_llm_models_async.assert_called()

    # Verify last_synced was updated for our provider
    updated_providers = await provider_manager.list_providers_async(name=byok_provider.name, actor=default_user)
    assert len(updated_providers) == 1
    assert updated_providers[0].last_synced is not None

    # Verify models were synced to database
    synced_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(synced_models) == 1
    assert synced_models[0].name == f"byok-gpt-4o-{test_id}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_byok_provider_last_synced_skips_sync_when_set(default_user, provider_manager):
    """Test that BYOK providers with last_synced set skip sync and read from DB."""
    from datetime import datetime, timezone

    from letta.schemas.providers import Provider
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create a BYOK provider
    byok_provider_create = ProviderCreate(
        name=f"test-byok-cached-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-key",
    )
    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

    # Manually sync models to DB
    cached_model = LLMConfig(
        model=f"cached-gpt-4o-{test_id}",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=64000,
        handle=f"test-byok-cached-{test_id}/gpt-4o",
        provider_name=byok_provider.name,
        provider_category=ProviderCategory.byok,
    )
    await provider_manager.sync_provider_models_async(
        provider=byok_provider,
        llm_models=[cached_model],
        embedding_models=[],
        organization_id=default_user.organization_id,
    )

    # Set last_synced to indicate models are already synced
    await provider_manager.update_provider_last_synced_async(byok_provider.id)

    # Create server
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    # Mock cast_to_subtype - should NOT be called since last_synced is set
    mock_typed_provider = MagicMock()
    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=[])
    mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[])

    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
        # List BYOK models - should read from DB, not trigger sync
        byok_models = await server.list_llm_models_async(
            actor=default_user,
            provider_category=[ProviderCategory.byok],
        )

    # Verify sync was NOT triggered (cast_to_subtype should not be called)
    mock_typed_provider.list_llm_models_async.assert_not_called()

    # Verify we got the cached model from DB
    byok_handles = [m.handle for m in byok_models]
    assert f"test-byok-cached-{test_id}/gpt-4o" in byok_handles
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_base_provider_updates_last_synced_on_sync(default_user, provider_manager):
    """Test that base provider sync updates the last_synced timestamp."""
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create a base provider
    base_provider_create = ProviderCreate(
        name=f"test-base-sync-{test_id}",
        provider_type=ProviderType.openai,
        api_key="",  # Base providers don't store API keys
    )
    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)

    # Verify last_synced is null initially
    assert base_provider.last_synced is None

    # Sync models for the base provider
    base_model = LLMConfig(
        model=f"base-gpt-4o-{test_id}",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=64000,
        handle=f"test-base-sync-{test_id}/gpt-4o",
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[base_model],
        embedding_models=[],
        organization_id=None,  # base/global models are not org-scoped
    )
    # Explicit timestamp update, mirroring what the server does after a sync
    await provider_manager.update_provider_last_synced_async(base_provider.id)

    # Verify last_synced was updated
    updated_providers = await provider_manager.list_providers_async(name=base_provider.name, actor=default_user)
    assert len(updated_providers) == 1
    assert updated_providers[0].last_synced is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_byok_provider_models_synced_on_creation(default_user, provider_manager):
    """Test that models are automatically synced when a BYOK provider is created.

    When create_provider_async is called with is_byok=True, it should:
    1. Create the provider in the database
    2. Call _sync_default_models_for_provider to fetch and persist models from the provider API
    3. Update last_synced timestamp
    """
    from letta.schemas.providers import Provider

    test_id = generate_test_id()

    # Mock models that the provider API would return
    mock_llm_models = [
        LLMConfig(
            model="gpt-4o",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-byok-creation-{test_id}/gpt-4o",
            provider_name=f"test-byok-creation-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
        LLMConfig(
            model="gpt-4o-mini",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-byok-creation-{test_id}/gpt-4o-mini",
            provider_name=f"test-byok-creation-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
    ]
    mock_embedding_models = [
        EmbeddingConfig(
            embedding_model="text-embedding-3-small",
            embedding_endpoint_type="openai",
            embedding_endpoint="https://api.openai.com/v1",
            embedding_dim=1536,
            embedding_chunk_size=300,
            handle=f"test-byok-creation-{test_id}/text-embedding-3-small",
        ),
    ]

    # Create a mock typed provider that returns our test models
    mock_typed_provider = MagicMock()
    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_llm_models)
    mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=mock_embedding_models)

    # Patch cast_to_subtype to return our mock when _sync_default_models_for_provider is called
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
        # Create the BYOK provider - this should automatically sync models
        byok_provider_create = ProviderCreate(
            name=f"test-byok-creation-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-test-key",
        )
        byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

        # Verify the provider API was called during creation
        mock_typed_provider.list_llm_models_async.assert_called_once()
        mock_typed_provider.list_embedding_models_async.assert_called_once()

    # Re-fetch the provider to get the updated last_synced value
    # (the returned object from create_provider_async is stale since last_synced is set after)
    byok_provider = await provider_manager.get_provider_async(byok_provider.id, default_user)

    # Verify last_synced was set (indicating sync completed)
    assert byok_provider.last_synced is not None

    # Verify LLM models were persisted to the database
    synced_llm_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(synced_llm_models) == 2
    synced_llm_names = {m.name for m in synced_llm_models}
    assert "gpt-4o" in synced_llm_names
    assert "gpt-4o-mini" in synced_llm_names

    # Verify embedding models were persisted to the database
    synced_embedding_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="embedding",
        provider_id=byok_provider.id,
    )
    assert len(synced_embedding_models) == 1
    assert synced_embedding_models[0].name == "text-embedding-3-small"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_byok_provider_adds_new_models(default_user, provider_manager):
    """Test that refreshing a BYOK provider adds new models from the provider API.

    When _sync_default_models_for_provider is called (via refresh endpoint):
    1. It should fetch current models from the provider API
    2. Add any new models that weren't previously synced
    3. Update the last_synced timestamp
    """
    from letta.schemas.providers import Provider

    test_id = generate_test_id()
    provider_name = f"test-refresh-add-{test_id}"

    def _llm(model, context_window):
        # Build a BYOK-scoped OpenAI LLMConfig for this test provider.
        return LLMConfig(
            model=model,
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=context_window,
            handle=f"{provider_name}/{model}",
            provider_name=provider_name,
            provider_category=ProviderCategory.byok,
        )

    # Models advertised at creation time, and the expanded set after refresh.
    initial_models = [_llm("gpt-4o", 128000)]
    updated_models = [_llm("gpt-4o", 128000), _llm("gpt-4.1", 256000)]

    # Provider API mock used for the initial sync during creation.
    initial_api_mock = MagicMock()
    initial_api_mock.list_llm_models_async = AsyncMock(return_value=initial_models)
    initial_api_mock.list_embedding_models_async = AsyncMock(return_value=[])

    # Create the provider while the API reports only the initial model.
    with patch.object(Provider, "cast_to_subtype", return_value=initial_api_mock):
        create_request = ProviderCreate(
            name=provider_name,
            provider_type=ProviderType.openai,
            api_key="sk-test-key",
        )
        byok_provider = await provider_manager.create_provider_async(create_request, actor=default_user, is_byok=True)

    # Re-fetch: last_synced is written after create_provider_async returns its object.
    byok_provider = await provider_manager.get_provider_async(byok_provider.id, default_user)

    # The creation-time sync should have persisted exactly one LLM model.
    models_after_create = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(models_after_create) == 1
    assert models_after_create[0].name == "gpt-4o"

    first_sync_time = byok_provider.last_synced
    assert first_sync_time is not None  # sync ran during creation

    # Provider API mock for the refresh, now advertising the extra model.
    refresh_api_mock = MagicMock()
    refresh_api_mock.list_llm_models_async = AsyncMock(return_value=updated_models)
    refresh_api_mock.list_embedding_models_async = AsyncMock(return_value=[])

    # Drive the same sync path the refresh endpoint uses.
    with patch.object(Provider, "cast_to_subtype", return_value=refresh_api_mock):
        await provider_manager._sync_default_models_for_provider(byok_provider, default_user)

    refresh_api_mock.list_llm_models_async.assert_called_once()

    # Both the original and the newly advertised model should now be persisted.
    models_after_refresh = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(models_after_refresh) == 2
    names_after_refresh = {m.name for m in models_after_refresh}
    assert "gpt-4o" in names_after_refresh
    assert "gpt-4.1" in names_after_refresh

    # last_synced must be set and must not move backwards across a refresh.
    refreshed_provider = await provider_manager.get_provider_async(byok_provider.id, default_user)
    assert refreshed_provider.last_synced is not None
    assert refreshed_provider.last_synced >= first_sync_time
@pytest.mark.asyncio
async def test_refresh_byok_provider_removes_old_models(default_user, provider_manager):
    """Test that refreshing a BYOK provider removes models no longer available from the provider API.

    When _sync_default_models_for_provider is called (via refresh endpoint):
    1. It should fetch current models from the provider API
    2. Remove any models that are no longer available (soft delete)
    3. Keep models that are still available
    """
    from letta.schemas.providers import Provider

    test_id = generate_test_id()
    provider_name = f"test-refresh-remove-{test_id}"

    def _llm(model, context_window):
        # Build a BYOK-scoped OpenAI LLMConfig for this test provider.
        return LLMConfig(
            model=model,
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=context_window,
            handle=f"{provider_name}/{model}",
            provider_name=provider_name,
            provider_category=ProviderCategory.byok,
        )

    # At creation the provider advertises two models; after refresh gpt-4-turbo
    # has been dropped by the upstream API.
    initial_models = [_llm("gpt-4o", 128000), _llm("gpt-4-turbo", 128000)]
    updated_models = [_llm("gpt-4o", 128000)]

    # Provider API mock used for the initial sync during creation.
    initial_api_mock = MagicMock()
    initial_api_mock.list_llm_models_async = AsyncMock(return_value=initial_models)
    initial_api_mock.list_embedding_models_async = AsyncMock(return_value=[])

    # Create the provider while both models are available.
    with patch.object(Provider, "cast_to_subtype", return_value=initial_api_mock):
        create_request = ProviderCreate(
            name=provider_name,
            provider_type=ProviderType.openai,
            api_key="sk-test-key",
        )
        byok_provider = await provider_manager.create_provider_async(create_request, actor=default_user, is_byok=True)

    # The creation-time sync should have persisted both LLM models.
    models_after_create = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(models_after_create) == 2
    names_after_create = {m.name for m in models_after_create}
    assert "gpt-4o" in names_after_create
    assert "gpt-4-turbo" in names_after_create

    # Provider API mock for the refresh, now missing gpt-4-turbo.
    refresh_api_mock = MagicMock()
    refresh_api_mock.list_llm_models_async = AsyncMock(return_value=updated_models)
    refresh_api_mock.list_embedding_models_async = AsyncMock(return_value=[])

    # Drive the same sync path the refresh endpoint uses.
    with patch.object(Provider, "cast_to_subtype", return_value=refresh_api_mock):
        await provider_manager._sync_default_models_for_provider(byok_provider, default_user)

    # Only the still-available model should remain listed.
    models_after_refresh = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(models_after_refresh) == 1
    assert models_after_refresh[0].name == "gpt-4o"

    # The dropped model should have been soft-deleted out of the listing.
    names_after_refresh = {m.name for m in models_after_refresh}
    assert "gpt-4-turbo" not in names_after_refresh
@pytest.mark.asyncio
async def test_refresh_base_provider_fails(default_user, provider_manager):
    """Test that attempting to refresh a base provider returns an error.

    The refresh endpoint should only work for BYOK providers, not base providers.
    Base providers are managed by environment variables and shouldn't be refreshed.
    """
    from fastapi import HTTPException

    from letta.server.rest_api.routers.v1.providers import refresh_provider_models
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Set up a base provider (base providers carry no API key).
    create_request = ProviderCreate(
        name=f"test-base-refresh-{test_id}",
        provider_type=ProviderType.openai,
        api_key="",  # Base providers don't store API keys
    )
    base_provider = await provider_manager.create_provider_async(create_request, actor=default_user, is_byok=False)

    # Sanity check: the provider really is categorized as base.
    assert base_provider.provider_category == ProviderCategory.base

    # Minimal server wired to the test's provider manager.
    server = SyncServer(init_with_default_org_and_user=False)
    server.provider_manager = provider_manager

    # Fake request headers carrying the actor id.
    fake_headers = MagicMock()
    fake_headers.actor_id = default_user.id

    # Resolve the actor to our test user without touching the real user manager.
    server.user_manager = MagicMock()
    server.user_manager.get_actor_or_default_async = AsyncMock(return_value=default_user)

    # Refreshing a base provider must be rejected with an HTTP error.
    with pytest.raises(HTTPException) as exc_info:
        await refresh_provider_models(
            provider_id=base_provider.id,
            headers=fake_headers,
            server=server,
        )

    assert exc_info.value.status_code == 400
    assert "BYOK" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_model_by_handle_prioritizes_byok_over_base(default_user, provider_manager):
    """Test that get_model_by_handle_async returns the BYOK model when both BYOK and base providers have the same handle.

    This tests the legacy scenario where a user has both a BYOK provider and a base provider
    with the same name (and thus models with the same handle). The BYOK model should be
    returned because it's organization-specific, while base models are global.
    """
    from letta.orm.provider import Provider as ProviderORM
    from letta.schemas.providers import Provider as PydanticProvider
    from letta.server.db import db_registry

    test_id = generate_test_id()
    provider_name = f"test-duplicate-{test_id}"
    model_handle = f"{provider_name}/gpt-4o"

    # Step 1: Create a base provider and sync a model for it (global, organization_id=None)
    base_create_request = ProviderCreate(
        name=provider_name,
        provider_type=ProviderType.openai,
        api_key="",  # Base providers don't store API keys
    )
    base_provider = await provider_manager.create_provider_async(base_create_request, actor=default_user, is_byok=False)
    assert base_provider.provider_category == ProviderCategory.base

    # Global model (organization_id=None) attached to the base provider.
    base_llm_config = LLMConfig(
        model="gpt-4o",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=model_handle,
        provider_name=provider_name,
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[base_llm_config],
        embedding_models=[],
        organization_id=None,  # Global model
    )

    # Sanity check: the base (global) model resolves by handle.
    resolved_base_model = await provider_manager.get_model_by_handle_async(
        handle=model_handle,
        actor=default_user,
        model_type="llm",
    )
    assert resolved_base_model is not None
    assert resolved_base_model.handle == model_handle
    assert resolved_base_model.organization_id is None  # Global model

    # Step 2: Create a BYOK provider with the same name (simulating legacy duplicate).
    # Production now prevents this name clash, so we bypass the conflict check by
    # inserting the provider row directly through the ORM, as legacy data would have.
    byok_pydantic = PydanticProvider(
        name=provider_name,  # Same name as base provider
        provider_type=ProviderType.openai,
        provider_category=ProviderCategory.byok,
        organization_id=default_user.organization_id,
    )
    byok_pydantic.resolve_identifier()

    async with db_registry.async_session() as session:
        byok_row = ProviderORM(**byok_pydantic.model_dump(to_orm=True))
        await byok_row.create_async(session, actor=default_user)
        byok_provider = byok_row.to_pydantic()

    assert byok_provider.provider_category == ProviderCategory.byok

    # Org-specific model sharing the exact same handle as the base one.
    byok_llm_config = LLMConfig(
        model="gpt-4o",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=model_handle,  # Same handle as base model
        provider_name=provider_name,
        provider_category=ProviderCategory.byok,
    )
    await provider_manager.sync_provider_models_async(
        provider=byok_provider,
        llm_models=[byok_llm_config],
        embedding_models=[],
        organization_id=default_user.organization_id,  # Org-specific model
    )

    # Step 3: With both rows present, handle lookup must prefer the BYOK model.
    resolved_model = await provider_manager.get_model_by_handle_async(
        handle=model_handle,
        actor=default_user,
        model_type="llm",
    )

    assert resolved_model is not None
    assert resolved_model.handle == model_handle
    # The key assertion: org-specific (BYOK) model wins over the global (base) model.
    assert resolved_model.organization_id == default_user.organization_id
    assert resolved_model.provider_id == byok_provider.id
Reference in New Issue
Block a user