feat: Surface vector db in source object [LET-4257] (#4479)
* Surface vector db * Change default * Fern autogen * Fix sqlite test managers
This commit is contained in:
@@ -3,7 +3,9 @@ from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -40,6 +42,10 @@ class Source(BaseSource):
|
||||
metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="Metadata associated with the source.")
|
||||
|
||||
# metadata fields
|
||||
vector_db_provider: VectorDBProvider = Field(
|
||||
default=VectorDBProvider.NATIVE,
|
||||
description="The vector database provider used for this source's passages",
|
||||
)
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.")
|
||||
|
||||
@@ -3,12 +3,14 @@ from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import and_, exists, select
|
||||
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.source import Source as SourceModel
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.source import Source as PydanticSource, SourceUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
@@ -50,9 +52,12 @@ class SourceManager:
|
||||
if db_source:
|
||||
return db_source
|
||||
else:
|
||||
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Provide default embedding config if not given
|
||||
source.organization_id = actor.organization_id
|
||||
source.vector_db_provider = vector_db_provider
|
||||
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
|
||||
await source.create_async(session, actor=actor)
|
||||
return source.to_pydantic()
|
||||
@@ -91,6 +96,10 @@ class SourceManager:
|
||||
Returns:
|
||||
List of created/updated sources
|
||||
"""
|
||||
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
|
||||
for pydantic_source in pydantic_sources:
|
||||
pydantic_source.vector_db_provider = vector_db_provider
|
||||
|
||||
if not pydantic_sources:
|
||||
return []
|
||||
|
||||
@@ -164,7 +173,7 @@ class SourceManager:
|
||||
# update existing source
|
||||
from letta.schemas.source import SourceUpdate
|
||||
|
||||
update_data = source.model_dump(exclude={"id"}, exclude_none=True)
|
||||
update_data = source.model_dump(exclude={"id", "vector_db_provider"}, exclude_none=True)
|
||||
updated_source = await self.update_source(existing_source.id, SourceUpdate(**update_data), actor)
|
||||
sources.append(updated_source)
|
||||
else:
|
||||
|
||||
@@ -63,6 +63,7 @@ from letta.schemas.enums import (
|
||||
StepStatus,
|
||||
TagMatchMode,
|
||||
ToolType,
|
||||
VectorDBProvider,
|
||||
)
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.file import FileMetadata, FileMetadata as PydanticFileMetadata
|
||||
@@ -7203,6 +7204,57 @@ async def test_create_source(server: SyncServer, default_user):
|
||||
assert source.organization_id == default_user.organization_id
|
||||
|
||||
|
||||
async def test_source_vector_db_provider_with_tpuf(server: SyncServer, default_user):
|
||||
"""Test that vector_db_provider is correctly set based on should_use_tpuf."""
|
||||
from letta.settings import settings
|
||||
|
||||
# save original values
|
||||
original_use_tpuf = settings.use_tpuf
|
||||
original_tpuf_api_key = settings.tpuf_api_key
|
||||
|
||||
try:
|
||||
# test when should_use_tpuf returns True (expect TPUF provider)
|
||||
settings.use_tpuf = True
|
||||
settings.tpuf_api_key = "test_key"
|
||||
|
||||
# need to mock it in source_manager since it's already imported
|
||||
with patch("letta.services.source_manager.should_use_tpuf", return_value=True):
|
||||
source_pydantic = PydanticSource(
|
||||
name="Test Source TPUF",
|
||||
description="Source with TPUF provider",
|
||||
metadata={"type": "test"},
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
vector_db_provider=VectorDBProvider.TPUF, # explicitly set it
|
||||
)
|
||||
assert source_pydantic.vector_db_provider == VectorDBProvider.TPUF
|
||||
|
||||
# create source and verify it's saved with TPUF provider
|
||||
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
assert source.vector_db_provider == VectorDBProvider.TPUF
|
||||
|
||||
# test when should_use_tpuf returns False (expect NATIVE provider)
|
||||
settings.use_tpuf = False
|
||||
settings.tpuf_api_key = None
|
||||
|
||||
with patch("letta.services.source_manager.should_use_tpuf", return_value=False):
|
||||
source_pydantic = PydanticSource(
|
||||
name="Test Source Native",
|
||||
description="Source with Native provider",
|
||||
metadata={"type": "test"},
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
vector_db_provider=VectorDBProvider.NATIVE, # explicitly set it
|
||||
)
|
||||
assert source_pydantic.vector_db_provider == VectorDBProvider.NATIVE
|
||||
|
||||
# create source and verify it's saved with NATIVE provider
|
||||
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
assert source.vector_db_provider == VectorDBProvider.NATIVE
|
||||
finally:
|
||||
# restore original values
|
||||
settings.use_tpuf = original_use_tpuf
|
||||
settings.tpuf_api_key = original_tpuf_api_key
|
||||
|
||||
|
||||
async def test_create_sources_with_same_name_raises_error(server: SyncServer, default_user):
|
||||
"""Test that creating sources with the same name raises an IntegrityError due to unique constraint."""
|
||||
name = "Test Source"
|
||||
|
||||
Reference in New Issue
Block a user