feat: Surface vector db in source object [LET-4257] (#4479)

* Surface vector db

* Change default

* Fern autogen

* Fix sqlite test managers
This commit is contained in:
Matthew Zhou
2025-09-08 18:08:55 -07:00
committed by GitHub
parent 6a9eb16f4e
commit 20881b3383
3 changed files with 68 additions and 1 deletions

View File

@@ -3,7 +3,9 @@ from typing import Optional
from pydantic import Field
from letta.helpers.tpuf_client import should_use_tpuf
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import VectorDBProvider
from letta.schemas.letta_base import LettaBase
@@ -40,6 +42,10 @@ class Source(BaseSource):
metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="Metadata associated with the source.")
# metadata fields
vector_db_provider: VectorDBProvider = Field(
default=VectorDBProvider.NATIVE,
description="The vector database provider used for this source's passages",
)
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.")

View File

@@ -3,12 +3,14 @@ from typing import List, Optional, Union
from sqlalchemy import and_, exists, select
from letta.helpers.tpuf_client import should_use_tpuf
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.source import Source as SourceModel
from letta.orm.sources_agents import SourcesAgents
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.enums import VectorDBProvider
from letta.schemas.source import Source as PydanticSource, SourceUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
@@ -50,9 +52,12 @@ class SourceManager:
if db_source:
return db_source
else:
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
async with db_registry.async_session() as session:
# Provide default embedding config if not given
source.organization_id = actor.organization_id
source.vector_db_provider = vector_db_provider
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
await source.create_async(session, actor=actor)
return source.to_pydantic()
@@ -91,6 +96,10 @@ class SourceManager:
Returns:
List of created/updated sources
"""
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
for pydantic_source in pydantic_sources:
pydantic_source.vector_db_provider = vector_db_provider
if not pydantic_sources:
return []
@@ -164,7 +173,7 @@ class SourceManager:
# update existing source
from letta.schemas.source import SourceUpdate
update_data = source.model_dump(exclude={"id"}, exclude_none=True)
update_data = source.model_dump(exclude={"id", "vector_db_provider"}, exclude_none=True)
updated_source = await self.update_source(existing_source.id, SourceUpdate(**update_data), actor)
sources.append(updated_source)
else:

View File

@@ -63,6 +63,7 @@ from letta.schemas.enums import (
StepStatus,
TagMatchMode,
ToolType,
VectorDBProvider,
)
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.file import FileMetadata, FileMetadata as PydanticFileMetadata
@@ -7203,6 +7204,57 @@ async def test_create_source(server: SyncServer, default_user):
assert source.organization_id == default_user.organization_id
async def test_source_vector_db_provider_with_tpuf(server: SyncServer, default_user):
    """Test that create_source sets vector_db_provider based on should_use_tpuf.

    Each case deliberately builds the input source with the *opposite*
    provider from the one expected after creation. The assertions can then
    only pass if the source manager itself overrides the field from
    should_use_tpuf(), rather than echoing back what the caller supplied.
    """
    from letta.settings import settings

    # save original values so the finally block can restore them
    original_use_tpuf = settings.use_tpuf
    original_tpuf_api_key = settings.tpuf_api_key

    try:
        # test when should_use_tpuf returns True (expect TPUF provider)
        settings.use_tpuf = True
        settings.tpuf_api_key = "test_key"

        # need to mock it in source_manager since it's already imported
        with patch("letta.services.source_manager.should_use_tpuf", return_value=True):
            source_pydantic = PydanticSource(
                name="Test Source TPUF",
                description="Source with TPUF provider",
                metadata={"type": "test"},
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                # vector_db_provider intentionally omitted: the schema default is NATIVE
            )
            assert source_pydantic.vector_db_provider == VectorDBProvider.NATIVE

            # create source and verify the manager switched it to TPUF
            source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
            assert source.vector_db_provider == VectorDBProvider.TPUF

        # test when should_use_tpuf returns False (expect NATIVE provider)
        settings.use_tpuf = False
        settings.tpuf_api_key = None

        with patch("letta.services.source_manager.should_use_tpuf", return_value=False):
            source_pydantic = PydanticSource(
                name="Test Source Native",
                description="Source with Native provider",
                metadata={"type": "test"},
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                vector_db_provider=VectorDBProvider.TPUF,  # opposite of the expected result
            )
            assert source_pydantic.vector_db_provider == VectorDBProvider.TPUF

            # create source and verify the manager reset it to NATIVE
            source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
            assert source.vector_db_provider == VectorDBProvider.NATIVE
    finally:
        # restore original values
        settings.use_tpuf = original_use_tpuf
        settings.tpuf_api_key = original_tpuf_api_key
async def test_create_sources_with_same_name_raises_error(server: SyncServer, default_user):
"""Test that creating sources with the same name raises an IntegrityError due to unique constraint."""
name = "Test Source"