diff --git a/letta/schemas/source.py b/letta/schemas/source.py index 008da488..cd816ef3 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -3,7 +3,9 @@ from typing import Optional from pydantic import Field +from letta.helpers.tpuf_client import should_use_tpuf from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import VectorDBProvider from letta.schemas.letta_base import LettaBase @@ -40,6 +42,10 @@ class Source(BaseSource): metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="Metadata associated with the source.") # metadata fields + vector_db_provider: VectorDBProvider = Field( + default=VectorDBProvider.NATIVE, + description="The vector database provider used for this source's passages", + ) created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.") diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 38a21437..efe9e650 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -3,12 +3,14 @@ from typing import List, Optional, Union from sqlalchemy import and_, exists, select +from letta.helpers.tpuf_client import should_use_tpuf from letta.orm import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.source import Source as SourceModel from letta.orm.sources_agents import SourcesAgents from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState as PydanticAgentState +from letta.schemas.enums import VectorDBProvider from letta.schemas.source import Source as PydanticSource, SourceUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry @@ -50,9 +52,12 @@ class SourceManager: if db_source: return db_source else: + vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE + async with db_registry.async_session() as session: # Provide default embedding config if not given source.organization_id = actor.organization_id + source.vector_db_provider = vector_db_provider source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True)) await source.create_async(session, actor=actor) return source.to_pydantic() @@ -91,6 +96,10 @@ class SourceManager: Returns: List of created/updated sources """ + vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE + for pydantic_source in pydantic_sources: + pydantic_source.vector_db_provider = vector_db_provider + if not pydantic_sources: return [] @@ -164,7 +173,7 @@ class SourceManager: # update existing source from letta.schemas.source import SourceUpdate - update_data = source.model_dump(exclude={"id"}, exclude_none=True) + update_data = source.model_dump(exclude={"id", "vector_db_provider"}, exclude_none=True) updated_source = await self.update_source(existing_source.id, SourceUpdate(**update_data), actor) sources.append(updated_source) else: diff --git a/tests/test_managers.py b/tests/test_managers.py index 8a668189..73882f49 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -63,6 +63,7 @@ from letta.schemas.enums import ( StepStatus, TagMatchMode, ToolType, + VectorDBProvider, ) from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate from letta.schemas.file import FileMetadata, FileMetadata as PydanticFileMetadata @@ -7203,6 +7204,57 @@ async def test_create_source(server: SyncServer, default_user): assert source.organization_id == default_user.organization_id +async def test_source_vector_db_provider_with_tpuf(server: SyncServer, default_user): + """Test that vector_db_provider is correctly set based on should_use_tpuf.""" + from letta.settings import settings + + # save original values + original_use_tpuf = settings.use_tpuf + original_tpuf_api_key = settings.tpuf_api_key + + try: + # test when should_use_tpuf returns True (expect TPUF provider) + settings.use_tpuf = True + settings.tpuf_api_key = "test_key" + + # need to mock it in source_manager since it's already imported + with patch("letta.services.source_manager.should_use_tpuf", return_value=True): + source_pydantic = PydanticSource( + name="Test Source TPUF", + description="Source with TPUF provider", + metadata={"type": "test"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + vector_db_provider=VectorDBProvider.TPUF, # explicitly set it + ) + assert source_pydantic.vector_db_provider == VectorDBProvider.TPUF + + # create source and verify it's saved with TPUF provider + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) + assert source.vector_db_provider == VectorDBProvider.TPUF + + # test when should_use_tpuf returns False (expect NATIVE provider) + settings.use_tpuf = False + settings.tpuf_api_key = None + + with patch("letta.services.source_manager.should_use_tpuf", return_value=False): + source_pydantic = PydanticSource( + name="Test Source Native", + description="Source with Native provider", + metadata={"type": "test"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + vector_db_provider=VectorDBProvider.NATIVE, # explicitly set it + ) + assert source_pydantic.vector_db_provider == VectorDBProvider.NATIVE + + # create source and verify it's saved with NATIVE provider + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) + assert source.vector_db_provider == VectorDBProvider.NATIVE + finally: + # restore original values + settings.use_tpuf = original_use_tpuf + settings.tpuf_api_key = original_tpuf_api_key + + async def test_create_sources_with_same_name_raises_error(server: SyncServer, default_user): """Test that creating sources with the same name raises an IntegrityError due to unique constraint.""" name = "Test Source"