feat: Add pinecone for cloud embedding (#3160)
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
"""Add total_chunks and chunks_embedded to files
|
||||
|
||||
Revision ID: 47d2277e530d
|
||||
Revises: 56254216524f
|
||||
Create Date: 2025-07-03 14:32:08.539280
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "47d2277e530d"
|
||||
down_revision: Union[str, None] = "56254216524f"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("files", sa.Column("total_chunks", sa.Integer(), nullable=True))
|
||||
op.add_column("files", sa.Column("chunks_embedded", sa.Integer(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("files", "chunks_embedded")
|
||||
op.drop_column("files", "total_chunks")
|
||||
# ### end Alembic commands ###
|
||||
@@ -364,3 +364,10 @@ REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
|
||||
MAX_FILES_OPEN = 5
|
||||
|
||||
GET_PROVIDERS_TIMEOUT_SECONDS = 10
|
||||
|
||||
# Pinecone related fields
|
||||
PINECONE_EMBEDDING_MODEL: str = "llama-text-embed-v2"
|
||||
PINECONE_TEXT_FIELD_NAME = "chunk_text"
|
||||
PINECONE_METRIC = "cosine"
|
||||
PINECONE_CLOUD = "aws"
|
||||
PINECONE_REGION = "us-east-1"
|
||||
|
||||
@@ -65,7 +65,7 @@ async def grep_files(
|
||||
raise NotImplementedError("Tool not implemented. Please contact the Letta team.")
|
||||
|
||||
|
||||
async def semantic_search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
|
||||
async def semantic_search_files(agent_state: "AgentState", query: str, limit: int = 5) -> List["FileMetadata"]:
|
||||
"""
|
||||
Get list of most relevant chunks from any file using vector/embedding search.
|
||||
|
||||
@@ -76,6 +76,7 @@ async def semantic_search_files(agent_state: "AgentState", query: str) -> List["
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
|
||||
Returns:
|
||||
List[FileMetadata]: List of matching files.
|
||||
|
||||
@@ -29,7 +29,6 @@ def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> d
|
||||
# "Field": Field,
|
||||
}
|
||||
env.update(globals())
|
||||
|
||||
# print("About to execute source code...")
|
||||
exec(source_code, env)
|
||||
# print("Source code executed successfully")
|
||||
|
||||
143
letta/helpers/pinecone_utils.py
Normal file
143
letta/helpers/pinecone_utils.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pinecone import PineconeAsyncio
|
||||
|
||||
from letta.constants import PINECONE_CLOUD, PINECONE_EMBEDDING_MODEL, PINECONE_METRIC, PINECONE_REGION, PINECONE_TEXT_FIELD_NAME
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def should_use_pinecone(verbose: bool = False):
|
||||
if verbose:
|
||||
logger.info(
|
||||
"Pinecone check: enable_pinecone=%s, api_key=%s, agent_index=%s, source_index=%s",
|
||||
settings.enable_pinecone,
|
||||
bool(settings.pinecone_api_key),
|
||||
bool(settings.pinecone_agent_index),
|
||||
bool(settings.pinecone_source_index),
|
||||
)
|
||||
|
||||
return settings.enable_pinecone and settings.pinecone_api_key and settings.pinecone_agent_index and settings.pinecone_source_index
|
||||
|
||||
|
||||
async def upsert_pinecone_indices():
|
||||
from pinecone import IndexEmbed, PineconeAsyncio
|
||||
|
||||
for index_name in get_pinecone_indices():
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
if not await pc.has_index(index_name):
|
||||
await pc.create_index_for_model(
|
||||
name=index_name,
|
||||
cloud=PINECONE_CLOUD,
|
||||
region=PINECONE_REGION,
|
||||
embed=IndexEmbed(model=PINECONE_EMBEDDING_MODEL, field_map={"text": PINECONE_TEXT_FIELD_NAME}, metric=PINECONE_METRIC),
|
||||
)
|
||||
|
||||
|
||||
def get_pinecone_indices() -> List[str]:
|
||||
return [settings.pinecone_agent_index, settings.pinecone_source_index]
|
||||
|
||||
|
||||
async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, chunks: List[str], actor: User):
|
||||
records = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
record = {
|
||||
"_id": f"{file_id}_{i}",
|
||||
PINECONE_TEXT_FIELD_NAME: chunk,
|
||||
"file_id": file_id,
|
||||
"source_id": source_id,
|
||||
}
|
||||
records.append(record)
|
||||
|
||||
return await upsert_records_to_pinecone_index(records, actor)
|
||||
|
||||
|
||||
async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
|
||||
from pinecone.exceptions.exceptions import NotFoundException
|
||||
|
||||
namespace = actor.organization_id
|
||||
try:
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
await dense_index.delete(
|
||||
filter={
|
||||
"file_id": {"$eq": file_id},
|
||||
},
|
||||
namespace=namespace,
|
||||
)
|
||||
except NotFoundException:
|
||||
logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
|
||||
|
||||
|
||||
async def delete_source_records_from_pinecone_index(source_id: str, actor: User):
|
||||
from pinecone.exceptions.exceptions import NotFoundException
|
||||
|
||||
namespace = actor.organization_id
|
||||
try:
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
await dense_index.delete(filter={"source_id": {"$eq": source_id}}, namespace=namespace)
|
||||
except NotFoundException:
|
||||
logger.warning(f"Pinecone namespace {namespace} not found for {source_id} and {actor.organization_id}")
|
||||
|
||||
|
||||
async def upsert_records_to_pinecone_index(records: List[dict], actor: User):
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
await dense_index.upsert_records(actor.organization_id, records)
|
||||
|
||||
|
||||
async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], actor: User) -> Dict[str, Any]:
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
namespace = actor.organization_id
|
||||
|
||||
try:
|
||||
# Search the dense index with reranking
|
||||
search_results = await dense_index.search(
|
||||
namespace=namespace,
|
||||
query={
|
||||
"top_k": limit,
|
||||
"inputs": {"text": query},
|
||||
"filter": filter,
|
||||
},
|
||||
rerank={"model": "bge-reranker-v2-m3", "top_n": limit, "rank_fields": [PINECONE_TEXT_FIELD_NAME]},
|
||||
)
|
||||
return search_results
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search Pinecone namespace {actor.organization_id}: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
|
||||
from pinecone.exceptions.exceptions import NotFoundException
|
||||
|
||||
namespace = actor.organization_id
|
||||
try:
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
|
||||
kwargs = {"namespace": namespace, "prefix": file_id}
|
||||
if limit is not None:
|
||||
kwargs["limit"] = limit
|
||||
if pagination_token is not None:
|
||||
kwargs["pagination_token"] = pagination_token
|
||||
|
||||
try:
|
||||
result = []
|
||||
async for ids in dense_index.list(**kwargs):
|
||||
result.extend(ids)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list Pinecone namespace {actor.organization_id}: {str(e)}")
|
||||
raise e
|
||||
except NotFoundException:
|
||||
logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
|
||||
@@ -60,6 +60,8 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
|
||||
)
|
||||
|
||||
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Any error message encountered during processing.")
|
||||
total_chunks: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Total number of chunks for the file.")
|
||||
chunks_embedded: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Number of chunks that have been embedded.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
||||
@@ -112,6 +114,8 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
|
||||
file_last_modified_date=self.file_last_modified_date,
|
||||
processing_status=self.processing_status,
|
||||
error_message=self.error_message,
|
||||
total_chunks=self.total_chunks,
|
||||
chunks_embedded=self.chunks_embedded,
|
||||
created_at=self.created_at,
|
||||
updated_at=self.updated_at,
|
||||
is_deleted=self.is_deleted,
|
||||
|
||||
@@ -41,6 +41,8 @@ class FileMetadata(FileMetadataBase):
|
||||
description="The current processing status of the file (e.g. pending, parsing, embedding, completed, error).",
|
||||
)
|
||||
error_message: Optional[str] = Field(default=None, description="Optional error message if the file failed processing.")
|
||||
total_chunks: Optional[int] = Field(default=None, description="Total number of chunks for the file.")
|
||||
chunks_embedded: Optional[int] = Field(default=None, description="Number of chunks that have been embedded.")
|
||||
|
||||
# orm metadata, optional fields
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.")
|
||||
@@ -52,6 +54,10 @@ class FileMetadata(FileMetadataBase):
|
||||
default=None, description="Optional full-text content of the file; only populated on demand due to its size."
|
||||
)
|
||||
|
||||
def is_processing_terminal(self) -> bool:
|
||||
"""Check if the file processing status is in a terminal state (completed or error)."""
|
||||
return self.processing_status in (FileProcessingStatus.COMPLETED, FileProcessingStatus.ERROR)
|
||||
|
||||
|
||||
class FileAgentBase(LettaBase):
|
||||
"""Base class for the FileMetadata-⇄-Agent association schemas"""
|
||||
|
||||
@@ -17,6 +17,7 @@ from letta.__init__ import __version__ as letta_version
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.helpers.pinecone_utils import get_pinecone_indices, should_use_pinecone, upsert_pinecone_indices
|
||||
from letta.jobs.scheduler import start_scheduler_with_leader_election
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
@@ -127,6 +128,16 @@ async def lifespan(app_: FastAPI):
|
||||
db_registry.initialize_async()
|
||||
logger.info(f"[Worker {worker_id}] Database connections initialized")
|
||||
|
||||
if should_use_pinecone():
|
||||
if settings.upsert_pinecone_indices:
|
||||
logger.info(f"[Worker {worker_id}] Upserting pinecone indices: {get_pinecone_indices()}")
|
||||
await upsert_pinecone_indices()
|
||||
logger.info(f"[Worker {worker_id}] Upserted pinecone indices")
|
||||
else:
|
||||
logger.info(f"[Worker {worker_id}] Enabled pinecone")
|
||||
else:
|
||||
logger.info(f"[Worker {worker_id}] Disabled pinecone")
|
||||
|
||||
logger.info(f"[Worker {worker_id}] Starting scheduler with leader election")
|
||||
global server
|
||||
try:
|
||||
|
||||
@@ -9,6 +9,12 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, UploadFile
|
||||
from starlette import status
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.helpers.pinecone_utils import (
|
||||
delete_file_records_from_pinecone_index,
|
||||
delete_source_records_from_pinecone_index,
|
||||
list_pinecone_index_for_files,
|
||||
should_use_pinecone,
|
||||
)
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -22,6 +28,7 @@ from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
|
||||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||||
from letta.services.file_processor.embedder.pinecone_embedder import PineconeEmbedder
|
||||
from letta.services.file_processor.file_processor import FileProcessor
|
||||
from letta.services.file_processor.file_types import (
|
||||
get_allowed_media_types,
|
||||
@@ -163,6 +170,10 @@ async def delete_source(
|
||||
files = await server.file_manager.list_files(source_id, actor)
|
||||
file_ids = [f.id for f in files]
|
||||
|
||||
if should_use_pinecone():
|
||||
logger.info(f"Deleting source {source_id} from pinecone index")
|
||||
await delete_source_records_from_pinecone_index(source_id=source_id, actor=actor)
|
||||
|
||||
for agent_state in agent_states:
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
|
||||
@@ -326,16 +337,24 @@ async def get_file_metadata(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# Verify the source exists and user has access
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")
|
||||
|
||||
# Get file metadata using the file manager
|
||||
file_metadata = await server.file_manager.get_file_by_id(
|
||||
file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True
|
||||
)
|
||||
|
||||
if should_use_pinecone() and not file_metadata.is_processing_terminal():
|
||||
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor, limit=file_metadata.total_chunks)
|
||||
logger.info(f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_id} in organization {actor.organization_id}")
|
||||
|
||||
if len(ids) != file_metadata.chunks_embedded or len(ids) == file_metadata.total_chunks:
|
||||
if len(ids) != file_metadata.total_chunks:
|
||||
file_status = file_metadata.processing_status
|
||||
else:
|
||||
file_status = FileProcessingStatus.COMPLETED
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
|
||||
)
|
||||
|
||||
if not file_metadata:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
|
||||
|
||||
@@ -364,6 +383,10 @@ async def delete_file_from_source(
|
||||
|
||||
await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor)
|
||||
|
||||
if should_use_pinecone():
|
||||
logger.info(f"Deleting file {file_id} from pinecone index")
|
||||
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
|
||||
|
||||
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True))
|
||||
if deleted_file is None:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
|
||||
@@ -402,8 +425,14 @@ async def load_file_to_source_cloud(
|
||||
):
|
||||
file_processor = MistralFileParser()
|
||||
text_chunker = LlamaIndexChunker(chunk_size=embedding_config.embedding_chunk_size)
|
||||
embedder = OpenAIEmbedder(embedding_config=embedding_config)
|
||||
file_processor = FileProcessor(file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor)
|
||||
using_pinecone = should_use_pinecone()
|
||||
if using_pinecone:
|
||||
embedder = PineconeEmbedder()
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=embedding_config)
|
||||
file_processor = FileProcessor(
|
||||
file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor, using_pinecone=using_pinecone
|
||||
)
|
||||
await file_processor.process(
|
||||
server=server, agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata
|
||||
)
|
||||
|
||||
@@ -109,15 +109,17 @@ class FileManager:
|
||||
actor: PydanticUser,
|
||||
processing_status: Optional[FileProcessingStatus] = None,
|
||||
error_message: Optional[str] = None,
|
||||
total_chunks: Optional[int] = None,
|
||||
chunks_embedded: Optional[int] = None,
|
||||
) -> PydanticFileMetadata:
|
||||
"""
|
||||
Update processing_status and/or error_message on a FileMetadata row.
|
||||
Update processing_status, error_message, total_chunks, and/or chunks_embedded on a FileMetadata row.
|
||||
|
||||
* 1st round-trip → UPDATE
|
||||
* 2nd round-trip → SELECT fresh row (same as read_async)
|
||||
"""
|
||||
|
||||
if processing_status is None and error_message is None:
|
||||
if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None:
|
||||
raise ValueError("Nothing to update")
|
||||
|
||||
values: dict[str, object] = {"updated_at": datetime.utcnow()}
|
||||
@@ -125,6 +127,10 @@ class FileManager:
|
||||
values["processing_status"] = processing_status
|
||||
if error_message is not None:
|
||||
values["error_message"] = error_message
|
||||
if total_chunks is not None:
|
||||
values["total_chunks"] = total_chunks
|
||||
if chunks_embedded is not None:
|
||||
values["chunks_embedded"] = chunks_embedded
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Fast in-place update – no ORM hydration
|
||||
|
||||
16
letta/services/file_processor/embedder/base_embedder.py
Normal file
16
letta/services/file_processor/embedder/base_embedder.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseEmbedder(ABC):
|
||||
"""Abstract base class for embedding generation"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
|
||||
"""Generate embeddings for chunks with batching and concurrent processing"""
|
||||
@@ -9,12 +9,13 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
|
||||
from letta.settings import model_settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIEmbedder:
|
||||
class OpenAIEmbedder(BaseEmbedder):
|
||||
"""OpenAI-based embedding generation"""
|
||||
|
||||
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
|
||||
@@ -24,6 +25,7 @@ class OpenAIEmbedder:
|
||||
else EmbeddingConfig.default_config(model_name="letta")
|
||||
)
|
||||
self.embedding_config = embedding_config or self.default_embedding_config
|
||||
self.max_concurrent_requests = 20
|
||||
|
||||
# TODO: Unify to global OpenAI client
|
||||
self.client: OpenAIClient = cast(
|
||||
@@ -34,7 +36,6 @@ class OpenAIEmbedder:
|
||||
actor=None, # Not necessary
|
||||
),
|
||||
)
|
||||
self.max_concurrent_requests = 20
|
||||
|
||||
@trace_method
|
||||
async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
|
||||
|
||||
74
letta/services/file_processor/embedder/pinecone_embedder.py
Normal file
74
letta/services/file_processor/embedder/pinecone_embedder.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import List
|
||||
|
||||
from letta.helpers.pinecone_utils import upsert_file_records_to_pinecone_index
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
|
||||
|
||||
try:
|
||||
PINECONE_AVAILABLE = True
|
||||
except ImportError:
|
||||
PINECONE_AVAILABLE = False
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PineconeEmbedder(BaseEmbedder):
|
||||
"""Pinecone-based embedding generation"""
|
||||
|
||||
def __init__(self):
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone package is not installed. Install it with: pip install pinecone")
|
||||
|
||||
super().__init__()
|
||||
|
||||
@trace_method
|
||||
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
|
||||
"""Generate embeddings and upsert to Pinecone, then return Passage objects"""
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
logger.info(f"Upserting {len(chunks)} chunks to Pinecone using namespace {source_id}")
|
||||
log_event(
|
||||
"embedder.generation_started",
|
||||
{
|
||||
"total_chunks": len(chunks),
|
||||
"file_id": file_id,
|
||||
"source_id": source_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Upsert records to Pinecone using source_id as namespace
|
||||
try:
|
||||
await upsert_file_records_to_pinecone_index(file_id=file_id, source_id=source_id, chunks=chunks, actor=actor)
|
||||
logger.info(f"Successfully kicked off upserting {len(chunks)} records to Pinecone")
|
||||
log_event(
|
||||
"embedder.upsert_started",
|
||||
{"records_upserted": len(chunks), "namespace": source_id, "file_id": file_id},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upsert records to Pinecone: {str(e)}")
|
||||
log_event("embedder.upsert_failed", {"error": str(e), "error_type": type(e).__name__})
|
||||
raise
|
||||
|
||||
# Create Passage objects (without embeddings since Pinecone handles them)
|
||||
passages = []
|
||||
for i, text in enumerate(chunks):
|
||||
passage = Passage(
|
||||
text=text,
|
||||
file_id=file_id,
|
||||
source_id=source_id,
|
||||
embedding=None, # Pinecone handles embeddings internally
|
||||
embedding_config=None, # None
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
logger.info(f"Successfully created {len(passages)} passages")
|
||||
log_event(
|
||||
"embedder.generation_completed",
|
||||
{"passages_created": len(passages), "total_chunks_processed": len(chunks), "file_id": file_id, "source_id": source_id},
|
||||
)
|
||||
return passages
|
||||
@@ -11,7 +11,7 @@ from letta.server.server import SyncServer
|
||||
from letta.services.file_manager import FileManager
|
||||
from letta.services.file_processor.chunker.line_chunker import LineChunker
|
||||
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
|
||||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||||
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
@@ -27,8 +27,9 @@ class FileProcessor:
|
||||
self,
|
||||
file_parser: MistralFileParser,
|
||||
text_chunker: LlamaIndexChunker,
|
||||
embedder: OpenAIEmbedder,
|
||||
embedder: BaseEmbedder,
|
||||
actor: User,
|
||||
using_pinecone: bool,
|
||||
max_file_size: int = 50 * 1024 * 1024, # 50MB default
|
||||
):
|
||||
self.file_parser = file_parser
|
||||
@@ -41,6 +42,7 @@ class FileProcessor:
|
||||
self.passage_manager = PassageManager()
|
||||
self.job_manager = JobManager()
|
||||
self.actor = actor
|
||||
self.using_pinecone = using_pinecone
|
||||
|
||||
# TODO: Factor this function out of SyncServer
|
||||
@trace_method
|
||||
@@ -109,7 +111,7 @@ class FileProcessor:
|
||||
|
||||
logger.info("Chunking extracted text")
|
||||
log_event("file_processor.chunking_started", {"filename": filename, "pages_to_process": len(ocr_response.pages)})
|
||||
all_passages = []
|
||||
all_chunks = []
|
||||
|
||||
for page in ocr_response.pages:
|
||||
chunks = self.text_chunker.chunk_text(page)
|
||||
@@ -118,24 +120,17 @@ class FileProcessor:
|
||||
log_event("file_processor.chunking_failed", {"filename": filename, "page_index": ocr_response.pages.index(page)})
|
||||
raise ValueError("No chunks created from text")
|
||||
|
||||
passages = await self.embedder.generate_embedded_passages(
|
||||
file_id=file_metadata.id, source_id=source_id, chunks=chunks, actor=self.actor
|
||||
)
|
||||
log_event(
|
||||
"file_processor.page_processed",
|
||||
{
|
||||
"filename": filename,
|
||||
"page_index": ocr_response.pages.index(page),
|
||||
"chunks_created": len(chunks),
|
||||
"passages_generated": len(passages),
|
||||
},
|
||||
)
|
||||
all_passages.extend(passages)
|
||||
all_chunks.extend(self.text_chunker.chunk_text(page))
|
||||
|
||||
all_passages = await self.passage_manager.create_many_source_passages_async(
|
||||
passages=all_passages, file_metadata=file_metadata, actor=self.actor
|
||||
all_passages = await self.embedder.generate_embedded_passages(
|
||||
file_id=file_metadata.id, source_id=source_id, chunks=all_chunks, actor=self.actor
|
||||
)
|
||||
log_event("file_processor.passages_created", {"filename": filename, "total_passages": len(all_passages)})
|
||||
|
||||
if not self.using_pinecone:
|
||||
all_passages = await self.passage_manager.create_many_source_passages_async(
|
||||
passages=all_passages, file_metadata=file_metadata, actor=self.actor
|
||||
)
|
||||
log_event("file_processor.passages_created", {"filename": filename, "total_passages": len(all_passages)})
|
||||
|
||||
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
|
||||
log_event(
|
||||
@@ -149,9 +144,14 @@ class FileProcessor:
|
||||
)
|
||||
|
||||
# update job status
|
||||
await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
|
||||
)
|
||||
if not self.using_pinecone:
|
||||
await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
|
||||
)
|
||||
else:
|
||||
await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=self.actor, total_chunks=len(all_passages), chunks_embedded=0
|
||||
)
|
||||
|
||||
return all_passages
|
||||
|
||||
|
||||
@@ -115,10 +115,6 @@ class JobManager:
|
||||
job.completed_at = get_utc_time().replace(tzinfo=None)
|
||||
if job.callback_url:
|
||||
await self._dispatch_callback_async(job)
|
||||
else:
|
||||
logger.info(f"Job does not contain callback url: {job}")
|
||||
else:
|
||||
logger.info(f"Job update is not terminal {job_update}")
|
||||
|
||||
# Save the updated job to the database
|
||||
await job.update_async(db_session=session, actor=actor)
|
||||
|
||||
@@ -19,7 +19,6 @@ class SourceManager:
|
||||
@trace_method
|
||||
async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
|
||||
"""Create a new source based on the PydanticSource schema."""
|
||||
# Try getting the source first by id
|
||||
db_source = await self.get_source_by_id(source.id, actor=actor)
|
||||
if db_source:
|
||||
return db_source
|
||||
|
||||
@@ -2,8 +2,9 @@ import asyncio
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from letta.constants import MAX_FILES_OPEN
|
||||
from letta.constants import MAX_FILES_OPEN, PINECONE_TEXT_FIELD_NAME
|
||||
from letta.functions.types import FileOpenRequest
|
||||
from letta.helpers.pinecone_utils import search_pinecone_index, should_use_pinecone
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -463,14 +464,15 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
@trace_method
|
||||
async def semantic_search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
|
||||
async def semantic_search_files(self, agent_state: AgentState, query: str, limit: int = 5) -> str:
|
||||
"""
|
||||
Search for text within attached files using semantic search and return passages with their source filenames.
|
||||
Uses Pinecone if configured, otherwise falls back to traditional search.
|
||||
|
||||
Args:
|
||||
agent_state: Current agent state
|
||||
query: Search query for semantic matching
|
||||
limit: Maximum number of results to return (default: 10)
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
|
||||
Returns:
|
||||
Formatted string with search results in IDE/terminal style
|
||||
@@ -485,6 +487,110 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
|
||||
self.logger.info(f"Semantic search started for agent {agent_state.id} with query '{query}' (limit: {limit})")
|
||||
|
||||
# Check if Pinecone is enabled and use it if available
|
||||
if should_use_pinecone():
|
||||
return await self._search_files_pinecone(agent_state, query, limit)
|
||||
else:
|
||||
return await self._search_files_traditional(agent_state, query, limit)
|
||||
|
||||
async def _search_files_pinecone(self, agent_state: AgentState, query: str, limit: int) -> str:
|
||||
"""Search files using Pinecone vector database."""
|
||||
|
||||
# Extract unique source_ids
|
||||
# TODO: Inefficient
|
||||
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
|
||||
source_ids = [source.id for source in attached_sources]
|
||||
if not source_ids:
|
||||
return f"No valid source IDs found for attached files"
|
||||
|
||||
# Get all attached files for this agent
|
||||
file_agents = await self.files_agents_manager.list_files_for_agent(agent_id=agent_state.id, actor=self.actor)
|
||||
if not file_agents:
|
||||
return "No files are currently attached to search"
|
||||
|
||||
results = []
|
||||
total_hits = 0
|
||||
files_with_matches = {}
|
||||
|
||||
try:
|
||||
filter = {"source_id": {"$in": source_ids}}
|
||||
search_results = await search_pinecone_index(query, limit, filter, self.actor)
|
||||
|
||||
# Process search results
|
||||
if "result" in search_results and "hits" in search_results["result"]:
|
||||
for hit in search_results["result"]["hits"]:
|
||||
if total_hits >= limit:
|
||||
break
|
||||
|
||||
total_hits += 1
|
||||
|
||||
# Extract hit information
|
||||
hit_id = hit.get("_id", "unknown")
|
||||
score = hit.get("_score", 0.0)
|
||||
fields = hit.get("fields", {})
|
||||
text = fields.get(PINECONE_TEXT_FIELD_NAME, "")
|
||||
file_id = fields.get("file_id", "")
|
||||
|
||||
# Find corresponding file name
|
||||
file_name = "Unknown File"
|
||||
for fa in file_agents:
|
||||
if fa.file_id == file_id:
|
||||
file_name = fa.file_name
|
||||
break
|
||||
|
||||
# Group by file name
|
||||
if file_name not in files_with_matches:
|
||||
files_with_matches[file_name] = []
|
||||
files_with_matches[file_name].append({"text": text, "score": score, "hit_id": hit_id})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Pinecone search failed: {str(e)}")
|
||||
raise e
|
||||
|
||||
if not files_with_matches:
|
||||
return f"No semantic matches found in Pinecone for query: '{query}'"
|
||||
|
||||
# Format results
|
||||
passage_num = 0
|
||||
for file_name, matches in files_with_matches.items():
|
||||
for match in matches:
|
||||
passage_num += 1
|
||||
|
||||
# Format each passage with terminal-style header
|
||||
score_display = f"(score: {match['score']:.3f})"
|
||||
passage_header = f"\n=== {file_name} (passage #{passage_num}) {score_display} ==="
|
||||
|
||||
# Format the passage text
|
||||
passage_text = match["text"].strip()
|
||||
lines = passage_text.splitlines()
|
||||
formatted_lines = []
|
||||
for line in lines[:20]: # Limit to first 20 lines per passage
|
||||
formatted_lines.append(f" {line}")
|
||||
|
||||
if len(lines) > 20:
|
||||
formatted_lines.append(f" ... [truncated {len(lines) - 20} more lines]")
|
||||
|
||||
passage_content = "\n".join(formatted_lines)
|
||||
results.append(f"{passage_header}\n{passage_content}")
|
||||
|
||||
# Mark access for files that had matches
|
||||
if files_with_matches:
|
||||
matched_file_names = [name for name in files_with_matches.keys() if name != "Unknown File"]
|
||||
if matched_file_names:
|
||||
await self.files_agents_manager.mark_access_bulk(agent_id=agent_state.id, file_names=matched_file_names, actor=self.actor)
|
||||
|
||||
# Create summary header
|
||||
file_count = len(files_with_matches)
|
||||
summary = f"Found {total_hits} Pinecone matches in {file_count} file{'s' if file_count != 1 else ''} for query: '{query}'"
|
||||
|
||||
# Combine all results
|
||||
formatted_results = [summary, "=" * len(summary)] + results
|
||||
|
||||
self.logger.info(f"Pinecone search completed: {total_hits} matches across {file_count} files")
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str:
|
||||
"""Traditional search using existing passage manager."""
|
||||
# Get semantic search results
|
||||
passages = await self.agent_manager.list_source_passages_async(
|
||||
actor=self.actor,
|
||||
|
||||
@@ -253,6 +253,13 @@ class Settings(BaseSettings):
|
||||
llm_request_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM requests in seconds")
|
||||
llm_stream_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM streaming requests in seconds")
|
||||
|
||||
# For embeddings
|
||||
enable_pinecone: bool = False
|
||||
pinecone_api_key: Optional[str] = None
|
||||
pinecone_source_index: Optional[str] = "sources"
|
||||
pinecone_agent_index: Optional[str] = "recall"
|
||||
upsert_pinecone_indices: bool = False
|
||||
|
||||
@property
|
||||
def letta_pg_uri(self) -> str:
|
||||
if self.pg_uri:
|
||||
|
||||
594
poetry.lock
generated
594
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -98,6 +98,7 @@ redis = {version = "^6.2.0", optional = true}
|
||||
structlog = "^25.4.0"
|
||||
certifi = "^2025.6.15"
|
||||
aioboto3 = {version = "^14.3.0", optional = true}
|
||||
pinecone = {extras = ["asyncio"], version = "^7.3.0"}
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
@@ -119,6 +120,7 @@ black = "^24.4.2"
|
||||
ipykernel = "^6.29.5"
|
||||
ipdb = "^0.13.13"
|
||||
pytest-mock = "^3.14.0"
|
||||
pinecone = "^7.3.0"
|
||||
|
||||
|
||||
[tool.poetry.group."dev,tests".dependencies]
|
||||
|
||||
@@ -28,6 +28,24 @@ def disable_e2b_api_key() -> Generator[None, None, None]:
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_pinecone() -> Generator[None, None, None]:
|
||||
"""
|
||||
Temporarily disables Pinecone by setting `settings.enable_pinecone` to False
|
||||
and `settings.pinecone_api_key` to None for the duration of the test.
|
||||
Restores the original values afterward.
|
||||
"""
|
||||
from letta.settings import settings
|
||||
|
||||
original_enable_pinecone = settings.enable_pinecone
|
||||
original_pinecone_api_key = settings.pinecone_api_key
|
||||
settings.enable_pinecone = False
|
||||
settings.pinecone_api_key = None
|
||||
yield
|
||||
settings.enable_pinecone = original_enable_pinecone
|
||||
settings.pinecone_api_key = original_pinecone_api_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_e2b_key_is_set():
|
||||
from letta.settings import tool_settings
|
||||
|
||||
@@ -3320,7 +3320,7 @@ async def test_update_tool_pip_requirements(server: SyncServer, print_tool, defa
|
||||
# Add pip requirements to existing tool
|
||||
pip_reqs = [
|
||||
PipRequirement(name="pandas", version="1.5.0"),
|
||||
PipRequirement(name="matplotlib"),
|
||||
PipRequirement(name="sumy"),
|
||||
]
|
||||
|
||||
tool_update = ToolUpdate(pip_requirements=pip_reqs)
|
||||
@@ -3334,7 +3334,7 @@ async def test_update_tool_pip_requirements(server: SyncServer, print_tool, defa
|
||||
assert len(updated_tool.pip_requirements) == 2
|
||||
assert updated_tool.pip_requirements[0].name == "pandas"
|
||||
assert updated_tool.pip_requirements[0].version == "1.5.0"
|
||||
assert updated_tool.pip_requirements[1].name == "matplotlib"
|
||||
assert updated_tool.pip_requirements[1].name == "sumy"
|
||||
assert updated_tool.pip_requirements[1].version is None
|
||||
|
||||
|
||||
@@ -5218,6 +5218,41 @@ async def test_update_file_status_error_only(server, default_user, default_sourc
|
||||
assert updated.processing_status == FileProcessingStatus.PENDING # default from creation
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_file_status_with_chunks(server, default_user, default_source):
|
||||
"""Update chunk progress fields along with status."""
|
||||
meta = PydanticFileMetadata(
|
||||
file_name="chunks_test.txt",
|
||||
file_path="/tmp/chunks_test.txt",
|
||||
file_type="text/plain",
|
||||
file_size=500,
|
||||
source_id=default_source.id,
|
||||
)
|
||||
created = await server.file_manager.create_file(file_metadata=meta, actor=default_user)
|
||||
|
||||
# Update with chunk progress
|
||||
updated = await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
processing_status=FileProcessingStatus.EMBEDDING,
|
||||
total_chunks=100,
|
||||
chunks_embedded=50,
|
||||
)
|
||||
assert updated.processing_status == FileProcessingStatus.EMBEDDING
|
||||
assert updated.total_chunks == 100
|
||||
assert updated.chunks_embedded == 50
|
||||
|
||||
# Update only chunk progress
|
||||
updated = await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
chunks_embedded=100,
|
||||
)
|
||||
assert updated.chunks_embedded == 100
|
||||
assert updated.total_chunks == 100 # unchanged
|
||||
assert updated.processing_status == FileProcessingStatus.EMBEDDING # unchanged
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_file_content_basic(server: SyncServer, default_user, default_source, async_session):
|
||||
"""Test creating and updating file content with upsert_file_content()."""
|
||||
|
||||
@@ -9,9 +9,10 @@ from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client.types import AgentState
|
||||
|
||||
from letta.constants import FILES_TOOLS
|
||||
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Constants
|
||||
@@ -49,7 +50,7 @@ def client() -> LettaSDKClient:
|
||||
yield client
|
||||
|
||||
|
||||
def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str, max_wait: int = 30):
|
||||
def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str, max_wait: int = 60):
|
||||
"""Helper function to upload a file and wait for processing to complete"""
|
||||
with open(file_path, "rb") as f:
|
||||
file_metadata = client.sources.files.upload(source_id=source_id, file=f)
|
||||
@@ -70,7 +71,7 @@ def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str,
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_state(client: LettaSDKClient):
|
||||
def agent_state(disable_pinecone, client: LettaSDKClient):
|
||||
open_file_tool = client.tools.list(name="open_files")[0]
|
||||
search_files_tool = client.tools.list(name="semantic_search_files")[0]
|
||||
grep_tool = client.tools.list(name="grep_files")[0]
|
||||
@@ -93,7 +94,7 @@ def agent_state(client: LettaSDKClient):
|
||||
# Tests
|
||||
|
||||
|
||||
def test_auto_attach_detach_files_tools(client: LettaSDKClient):
|
||||
def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient):
|
||||
"""Test automatic attachment and detachment of file tools when managing agent sources."""
|
||||
# Create agent with basic configuration
|
||||
agent = client.agents.create(
|
||||
@@ -164,6 +165,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
|
||||
],
|
||||
)
|
||||
def test_file_upload_creates_source_blocks_correctly(
|
||||
disable_pinecone,
|
||||
client: LettaSDKClient,
|
||||
agent_state: AgentState,
|
||||
file_path: str,
|
||||
@@ -204,7 +206,7 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
@@ -240,7 +242,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC
|
||||
assert not any("test" in b.value for b in blocks)
|
||||
|
||||
|
||||
def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
@@ -270,7 +272,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a
|
||||
assert not any("test" in b.value for b in blocks)
|
||||
|
||||
|
||||
def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -377,7 +379,7 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat
|
||||
print("✓ File successfully opened with different range - content differs as expected")
|
||||
|
||||
|
||||
def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -423,7 +425,7 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -465,7 +467,7 @@ def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: Ag
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -517,7 +519,7 @@ def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state:
|
||||
assert "513:" in tool_return_message.tool_return
|
||||
|
||||
|
||||
def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: LettaSDKClient):
|
||||
def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient):
|
||||
"""Test that creating an agent with source_ids parameter correctly creates source blocks."""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
@@ -560,7 +562,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: Le
|
||||
assert file_tools == set(FILES_TOOLS)
|
||||
|
||||
|
||||
def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -623,7 +625,7 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta
|
||||
)
|
||||
|
||||
|
||||
def test_duplicate_file_renaming(client: LettaSDKClient):
|
||||
def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient):
|
||||
"""Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small")
|
||||
@@ -662,7 +664,7 @@ def test_duplicate_file_renaming(client: LettaSDKClient):
|
||||
print(f" File {i+1}: original='{file.original_file_name}' → renamed='{file.file_name}'")
|
||||
|
||||
|
||||
def test_open_files_schema_descriptions(client: LettaSDKClient):
|
||||
def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
|
||||
"""Test that open_files tool schema contains correct descriptions from docstring"""
|
||||
|
||||
# Get the open_files tool
|
||||
@@ -743,3 +745,132 @@ def test_open_files_schema_descriptions(client: LettaSDKClient):
|
||||
expected_length_desc = "Optional number of lines to view from offset (inclusive). If not specified, views to end of file."
|
||||
assert length_prop["description"] == expected_length_desc
|
||||
assert length_prop["type"] == "integer"
|
||||
|
||||
|
||||
# --- Pinecone Tests ---
|
||||
|
||||
|
||||
def test_pinecone_search_files_tool(client: LettaSDKClient):
|
||||
"""Test that search_files tool uses Pinecone when enabled"""
|
||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||
|
||||
if not should_use_pinecone(verbose=True):
|
||||
pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")
|
||||
|
||||
print("Testing Pinecone search_files tool functionality")
|
||||
|
||||
# Create agent with file tools
|
||||
agent = client.agents.create(
|
||||
name="test_pinecone_agent",
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="username: testuser"),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Create source and attach to agent
|
||||
source = client.sources.create(name="test_pinecone_source", embedding="openai/text-embedding-3-small")
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent.id)
|
||||
|
||||
# Upload a file with searchable content
|
||||
file_path = "tests/data/long_test.txt"
|
||||
upload_file_and_wait(client, source.id, file_path)
|
||||
|
||||
# Test semantic search using Pinecone
|
||||
search_response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content="Use the semantic_search_files tool to search for 'electoral history' in the files.")],
|
||||
)
|
||||
|
||||
# Verify tool was called successfully
|
||||
tool_calls = [msg for msg in search_response.messages if msg.message_type == "tool_call_message"]
|
||||
assert len(tool_calls) > 0, "No tool calls found"
|
||||
assert any(tc.tool_call.name == "semantic_search_files" for tc in tool_calls), "semantic_search_files not called"
|
||||
|
||||
# Verify tool returned results
|
||||
tool_returns = [msg for msg in search_response.messages if msg.message_type == "tool_return_message"]
|
||||
assert len(tool_returns) > 0, "No tool returns found"
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
# Check that results contain expected content
|
||||
search_results = tool_returns[0].tool_return
|
||||
print(search_results)
|
||||
assert (
|
||||
"electoral" in search_results.lower() or "history" in search_results.lower()
|
||||
), f"Search results should contain relevant content: {search_results}"
|
||||
|
||||
|
||||
def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
|
||||
"""Test that file and source deletion removes records from Pinecone"""
|
||||
import asyncio
|
||||
|
||||
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
|
||||
|
||||
if not should_use_pinecone():
|
||||
pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")
|
||||
|
||||
print("Testing Pinecone file and source deletion lifecycle")
|
||||
|
||||
# Create source
|
||||
source = client.sources.create(name="test_lifecycle_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Upload multiple files and wait for processing
|
||||
file_paths = ["tests/data/test.txt", "tests/data/test.md"]
|
||||
uploaded_files = []
|
||||
for file_path in file_paths:
|
||||
file_metadata = upload_file_and_wait(client, source.id, file_path)
|
||||
uploaded_files.append(file_metadata)
|
||||
|
||||
# Get temp user for Pinecone operations
|
||||
user = User(name="temp", organization_id=DEFAULT_ORG_ID)
|
||||
|
||||
# Test file-level deletion first
|
||||
if len(uploaded_files) > 1:
|
||||
file_to_delete = uploaded_files[0]
|
||||
|
||||
# Check records for the specific file using list function
|
||||
records_before = asyncio.run(list_pinecone_index_for_files(file_to_delete.id, user))
|
||||
print(f"Found {len(records_before)} records for file before deletion")
|
||||
|
||||
# Delete the file
|
||||
client.sources.files.delete(source_id=source.id, file_id=file_to_delete.id)
|
||||
|
||||
# Allow time for deletion to propagate
|
||||
time.sleep(2)
|
||||
|
||||
# Verify file records are removed
|
||||
records_after = asyncio.run(list_pinecone_index_for_files(file_to_delete.id, user))
|
||||
print(f"Found {len(records_after)} records for file after deletion")
|
||||
|
||||
assert len(records_after) == 0, f"File records should be removed from Pinecone after deletion, but found {len(records_after)}"
|
||||
|
||||
# Test source-level deletion - check remaining files
|
||||
# Check records for remaining files
|
||||
remaining_records = []
|
||||
for file_metadata in uploaded_files[1:]: # Skip the already deleted file
|
||||
file_records = asyncio.run(list_pinecone_index_for_files(file_metadata.id, user))
|
||||
remaining_records.extend(file_records)
|
||||
|
||||
records_before = len(remaining_records)
|
||||
print(f"Found {records_before} records for remaining files before source deletion")
|
||||
|
||||
# Delete the entire source
|
||||
client.sources.delete(source_id=source.id)
|
||||
|
||||
# Allow time for deletion to propagate
|
||||
time.sleep(3)
|
||||
|
||||
# Verify all remaining file records are removed
|
||||
records_after = []
|
||||
for file_metadata in uploaded_files[1:]:
|
||||
file_records = asyncio.run(list_pinecone_index_for_files(file_metadata.id, user))
|
||||
records_after.extend(file_records)
|
||||
|
||||
print(f"Found {len(records_after)} records for files after source deletion")
|
||||
|
||||
assert (
|
||||
len(records_after) == 0
|
||||
), f"All source records should be removed from Pinecone after source deletion, but found {len(records_after)}"
|
||||
|
||||
print("✓ Pinecone lifecycle verified - namespace is clean after source deletion")
|
||||
|
||||
Reference in New Issue
Block a user