feat: Add pinecone for cloud embedding (#3160)

This commit is contained in:
Matthew Zhou
2025-07-03 22:37:55 -07:00
committed by GitHub
parent 006e9c9100
commit 9605d1f79c
23 changed files with 810 additions and 526 deletions

View File

@@ -0,0 +1,33 @@
"""Add total_chunks and chunks_embedded to files
Revision ID: 47d2277e530d
Revises: 56254216524f
Create Date: 2025-07-03 14:32:08.539280
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "47d2277e530d"
down_revision: Union[str, None] = "56254216524f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the nullable integer chunk-progress columns to the ``files`` table."""
    # Order matches the generated migration: total_chunks first, then chunks_embedded.
    for column_name in ("total_chunks", "chunks_embedded"):
        op.add_column("files", sa.Column(column_name, sa.Integer(), nullable=True))
def downgrade() -> None:
    """Drop the chunk-progress columns from the ``files`` table."""
    # Reverse of upgrade(): chunks_embedded is removed before total_chunks.
    for column_name in ("chunks_embedded", "total_chunks"):
        op.drop_column("files", column_name)

View File

@@ -364,3 +364,10 @@ REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
MAX_FILES_OPEN = 5
GET_PROVIDERS_TIMEOUT_SECONDS = 10
# Pinecone related fields
PINECONE_EMBEDDING_MODEL: str = "llama-text-embed-v2"
PINECONE_TEXT_FIELD_NAME = "chunk_text"
PINECONE_METRIC = "cosine"
PINECONE_CLOUD = "aws"
PINECONE_REGION = "us-east-1"

View File

@@ -65,7 +65,7 @@ async def grep_files(
raise NotImplementedError("Tool not implemented. Please contact the Letta team.")
async def semantic_search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
async def semantic_search_files(agent_state: "AgentState", query: str, limit: int = 5) -> List["FileMetadata"]:
"""
Get list of most relevant chunks from any file using vector/embedding search.
@@ -76,6 +76,7 @@ async def semantic_search_files(agent_state: "AgentState", query: str) -> List["
Args:
query (str): The search query.
limit: Maximum number of results to return (default: 5)
Returns:
List[FileMetadata]: List of matching files.

View File

@@ -29,7 +29,6 @@ def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> d
# "Field": Field,
}
env.update(globals())
# print("About to execute source code...")
exec(source_code, env)
# print("Source code executed successfully")

View File

@@ -0,0 +1,143 @@
from typing import Any, Dict, List
from pinecone import PineconeAsyncio
from letta.constants import PINECONE_CLOUD, PINECONE_EMBEDDING_MODEL, PINECONE_METRIC, PINECONE_REGION, PINECONE_TEXT_FIELD_NAME
from letta.log import get_logger
from letta.schemas.user import User
from letta.settings import settings
logger = get_logger(__name__)
def should_use_pinecone(verbose: bool = False) -> bool:
    """Report whether Pinecone is enabled and fully configured.

    Args:
        verbose: When true, log the outcome of each individual check. The API
            key and index names are logged only as booleans.

    Returns:
        True only if ``settings.enable_pinecone`` is set and the API key, the
        agent index name and the source index name are all truthy.
    """
    if verbose:
        logger.info(
            "Pinecone check: enable_pinecone=%s, api_key=%s, agent_index=%s, source_index=%s",
            settings.enable_pinecone,
            bool(settings.pinecone_api_key),
            bool(settings.pinecone_agent_index),
            bool(settings.pinecone_source_index),
        )
    # Collapse to a real bool: a bare ``and`` chain hands callers its last
    # operand (an index name, None, or an empty string) instead of True/False.
    return bool(
        settings.enable_pinecone and settings.pinecone_api_key and settings.pinecone_agent_index and settings.pinecone_source_index
    )
async def upsert_pinecone_indices():
    """Create every configured Pinecone index that does not exist yet.

    Each name returned by ``get_pinecone_indices()`` is checked with
    ``has_index``; missing ones are created with ``create_index_for_model``
    using the configured cloud, region, embedding model, text field mapping
    and metric. Indexes that already exist are left untouched.
    """
    from pinecone import IndexEmbed, PineconeAsyncio

    # One client session serves all indices; opening a new session per index
    # only adds connection setup and teardown.
    async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
        for index_name in get_pinecone_indices():
            if await pc.has_index(index_name):
                continue
            await pc.create_index_for_model(
                name=index_name,
                cloud=PINECONE_CLOUD,
                region=PINECONE_REGION,
                embed=IndexEmbed(model=PINECONE_EMBEDDING_MODEL, field_map={"text": PINECONE_TEXT_FIELD_NAME}, metric=PINECONE_METRIC),
            )
def get_pinecone_indices() -> List[str]:
    """Return the configured index names: the agent index first, then the source index."""
    agent_index = settings.pinecone_agent_index
    source_index = settings.pinecone_source_index
    return [agent_index, source_index]
async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, chunks: List[str], actor: User):
    """Turn a file's text chunks into Pinecone records and upsert them.

    Each record gets the id ``"{file_id}_{position}"`` (position is the chunk's
    index in ``chunks``), the chunk text under ``PINECONE_TEXT_FIELD_NAME``,
    and the ``file_id`` / ``source_id`` metadata fields.

    Returns:
        Whatever ``upsert_records_to_pinecone_index`` returns.
    """
    file_records = [
        {
            "_id": f"{file_id}_{position}",
            PINECONE_TEXT_FIELD_NAME: chunk_text,
            "file_id": file_id,
            "source_id": source_id,
        }
        for position, chunk_text in enumerate(chunks)
    ]
    return await upsert_records_to_pinecone_index(file_records, actor)
async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
    """Delete every record whose ``file_id`` field equals ``file_id``.

    The delete runs against the source index, inside the namespace named after
    the actor's organization id. A ``NotFoundException`` is logged as a warning
    and not re-raised.
    """
    from pinecone.exceptions.exceptions import NotFoundException

    namespace = actor.organization_id
    try:
        async with PineconeAsyncio(api_key=settings.pinecone_api_key) as client:
            index_description = await client.describe_index(name=settings.pinecone_source_index)
            index_host = index_description.index.host
            async with client.IndexAsyncio(host=index_host) as source_index:
                file_filter = {"file_id": {"$eq": file_id}}
                await source_index.delete(filter=file_filter, namespace=namespace)
    except NotFoundException:
        logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
async def delete_source_records_from_pinecone_index(source_id: str, actor: User):
    """Delete every record whose ``source_id`` field equals ``source_id``.

    The delete runs against the source index, inside the namespace named after
    the actor's organization id. A ``NotFoundException`` is logged as a warning
    and not re-raised.
    """
    from pinecone.exceptions.exceptions import NotFoundException

    namespace = actor.organization_id
    try:
        async with PineconeAsyncio(api_key=settings.pinecone_api_key) as client:
            index_description = await client.describe_index(name=settings.pinecone_source_index)
            index_host = index_description.index.host
            async with client.IndexAsyncio(host=index_host) as source_index:
                source_filter = {"source_id": {"$eq": source_id}}
                await source_index.delete(filter=source_filter, namespace=namespace)
    except NotFoundException:
        logger.warning(f"Pinecone namespace {namespace} not found for {source_id} and {actor.organization_id}")
# Pinecone limits text upserts into integrated-embedding indexes to 96 records
# per request, so larger payloads have to be split.
_PINECONE_UPSERT_BATCH_SIZE = 96


async def upsert_records_to_pinecone_index(records: List[dict], actor: User):
    """Upsert ``records`` into the source index under the actor's organization namespace.

    Records are sent in consecutive batches of at most
    ``_PINECONE_UPSERT_BATCH_SIZE`` so that files with many chunks do not
    exceed Pinecone's per-request record limit.

    Args:
        records: Record dicts to upsert (id, text field and metadata fields).
        actor: User whose ``organization_id`` is used as the namespace.
    """
    if not records:
        # Nothing to send; skip the describe_index round-trip entirely.
        return

    async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
        description = await pc.describe_index(name=settings.pinecone_source_index)
        async with pc.IndexAsyncio(host=description.index.host) as dense_index:
            for start in range(0, len(records), _PINECONE_UPSERT_BATCH_SIZE):
                batch = records[start : start + _PINECONE_UPSERT_BATCH_SIZE]
                await dense_index.upsert_records(actor.organization_id, batch)
async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], actor: User) -> Dict[str, Any]:
    """Run a reranked text search against the source index.

    Args:
        query: Text to search for.
        limit: Used both as ``top_k`` for the search and ``top_n`` for the reranker.
        filter: Metadata filter passed through to the search query.
        actor: User whose ``organization_id`` selects the namespace.

    Returns:
        The value returned by the index's ``search`` call.

    Raises:
        Exception: Any failure from the search is logged as a warning and re-raised.
    """
    search_query = {
        "top_k": limit,
        "inputs": {"text": query},
        "filter": filter,
    }
    rerank_options = {"model": "bge-reranker-v2-m3", "top_n": limit, "rank_fields": [PINECONE_TEXT_FIELD_NAME]}

    async with PineconeAsyncio(api_key=settings.pinecone_api_key) as client:
        index_description = await client.describe_index(name=settings.pinecone_source_index)
        async with client.IndexAsyncio(host=index_description.index.host) as source_index:
            namespace = actor.organization_id
            try:
                # Search the dense index with reranking
                return await source_index.search(namespace=namespace, query=search_query, rerank=rerank_options)
            except Exception as e:
                logger.warning(f"Failed to search Pinecone namespace {actor.organization_id}: {str(e)}")
                raise e
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
    """List the ids of all records in the source index whose id starts with ``file_id``.

    Args:
        file_id: Used as the id prefix for the listing.
        actor: User whose ``organization_id`` selects the namespace.
        limit: Forwarded to the index ``list`` call when not None.
        pagination_token: Forwarded to the index ``list`` call when not None.

    Returns:
        The ids from every page yielded by the listing, or an empty list when
        Pinecone raises ``NotFoundException``.

    Raises:
        Exception: Any other listing failure is logged as a warning and re-raised.
    """
    from pinecone.exceptions.exceptions import NotFoundException

    namespace = actor.organization_id
    try:
        async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
            description = await pc.describe_index(name=settings.pinecone_source_index)
            async with pc.IndexAsyncio(host=description.index.host) as dense_index:
                kwargs = {"namespace": namespace, "prefix": file_id}
                if limit is not None:
                    kwargs["limit"] = limit
                if pagination_token is not None:
                    kwargs["pagination_token"] = pagination_token

                try:
                    record_ids: List[str] = []
                    async for page_ids in dense_index.list(**kwargs):
                        record_ids.extend(page_ids)
                    return record_ids
                except Exception as e:
                    logger.warning(f"Failed to list Pinecone namespace {actor.organization_id}: {str(e)}")
                    raise
    except NotFoundException:
        logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
        # Honour the List[str] contract: callers take len() of the result, so
        # falling off the end and returning None would raise a TypeError.
        return []

View File

@@ -60,6 +60,8 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
)
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Any error message encountered during processing.")
total_chunks: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Total number of chunks for the file.")
chunks_embedded: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Number of chunks that have been embedded.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
@@ -112,6 +114,8 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
file_last_modified_date=self.file_last_modified_date,
processing_status=self.processing_status,
error_message=self.error_message,
total_chunks=self.total_chunks,
chunks_embedded=self.chunks_embedded,
created_at=self.created_at,
updated_at=self.updated_at,
is_deleted=self.is_deleted,

View File

@@ -41,6 +41,8 @@ class FileMetadata(FileMetadataBase):
description="The current processing status of the file (e.g. pending, parsing, embedding, completed, error).",
)
error_message: Optional[str] = Field(default=None, description="Optional error message if the file failed processing.")
total_chunks: Optional[int] = Field(default=None, description="Total number of chunks for the file.")
chunks_embedded: Optional[int] = Field(default=None, description="Number of chunks that have been embedded.")
# orm metadata, optional fields
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.")
@@ -52,6 +54,10 @@ class FileMetadata(FileMetadataBase):
default=None, description="Optional full-text content of the file; only populated on demand due to its size."
)
def is_processing_terminal(self) -> bool:
    """Return True once processing has ended, i.e. the status is COMPLETED or ERROR."""
    terminal_statuses = (FileProcessingStatus.COMPLETED, FileProcessingStatus.ERROR)
    return self.processing_status in terminal_statuses
class FileAgentBase(LettaBase):
"""Base class for the FileMetadata-⇄-Agent association schemas"""

View File

@@ -17,6 +17,7 @@ from letta.__init__ import __version__ as letta_version
from letta.agents.exceptions import IncompatibleAgentType
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
from letta.helpers.pinecone_utils import get_pinecone_indices, should_use_pinecone, upsert_pinecone_indices
from letta.jobs.scheduler import start_scheduler_with_leader_election
from letta.log import get_logger
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
@@ -127,6 +128,16 @@ async def lifespan(app_: FastAPI):
db_registry.initialize_async()
logger.info(f"[Worker {worker_id}] Database connections initialized")
if should_use_pinecone():
if settings.upsert_pinecone_indices:
logger.info(f"[Worker {worker_id}] Upserting pinecone indices: {get_pinecone_indices()}")
await upsert_pinecone_indices()
logger.info(f"[Worker {worker_id}] Upserted pinecone indices")
else:
logger.info(f"[Worker {worker_id}] Enabled pinecone")
else:
logger.info(f"[Worker {worker_id}] Disabled pinecone")
logger.info(f"[Worker {worker_id}] Starting scheduler with leader election")
global server
try:

View File

@@ -9,6 +9,12 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, UploadFile
from starlette import status
import letta.constants as constants
from letta.helpers.pinecone_utils import (
delete_file_records_from_pinecone_index,
delete_source_records_from_pinecone_index,
list_pinecone_index_for_files,
should_use_pinecone,
)
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -22,6 +28,7 @@ from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
from letta.services.file_processor.embedder.pinecone_embedder import PineconeEmbedder
from letta.services.file_processor.file_processor import FileProcessor
from letta.services.file_processor.file_types import (
get_allowed_media_types,
@@ -163,6 +170,10 @@ async def delete_source(
files = await server.file_manager.list_files(source_id, actor)
file_ids = [f.id for f in files]
if should_use_pinecone():
logger.info(f"Deleting source {source_id} from pinecone index")
await delete_source_records_from_pinecone_index(source_id=source_id, actor=actor)
for agent_state in agent_states:
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
@@ -326,16 +337,24 @@ async def get_file_metadata(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# Verify the source exists and user has access
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
if not source:
raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")
# Get file metadata using the file manager
file_metadata = await server.file_manager.get_file_by_id(
file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True
)
if should_use_pinecone() and not file_metadata.is_processing_terminal():
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor, limit=file_metadata.total_chunks)
logger.info(f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_id} in organization {actor.organization_id}")
if len(ids) != file_metadata.chunks_embedded or len(ids) == file_metadata.total_chunks:
if len(ids) != file_metadata.total_chunks:
file_status = file_metadata.processing_status
else:
file_status = FileProcessingStatus.COMPLETED
await server.file_manager.update_file_status(
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
)
if not file_metadata:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
@@ -364,6 +383,10 @@ async def delete_file_from_source(
await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor)
if should_use_pinecone():
logger.info(f"Deleting file {file_id} from pinecone index")
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True))
if deleted_file is None:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
@@ -402,8 +425,14 @@ async def load_file_to_source_cloud(
):
file_processor = MistralFileParser()
text_chunker = LlamaIndexChunker(chunk_size=embedding_config.embedding_chunk_size)
embedder = OpenAIEmbedder(embedding_config=embedding_config)
file_processor = FileProcessor(file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor)
using_pinecone = should_use_pinecone()
if using_pinecone:
embedder = PineconeEmbedder()
else:
embedder = OpenAIEmbedder(embedding_config=embedding_config)
file_processor = FileProcessor(
file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor, using_pinecone=using_pinecone
)
await file_processor.process(
server=server, agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata
)

View File

@@ -109,15 +109,17 @@ class FileManager:
actor: PydanticUser,
processing_status: Optional[FileProcessingStatus] = None,
error_message: Optional[str] = None,
total_chunks: Optional[int] = None,
chunks_embedded: Optional[int] = None,
) -> PydanticFileMetadata:
"""
Update processing_status and/or error_message on a FileMetadata row.
Update processing_status, error_message, total_chunks, and/or chunks_embedded on a FileMetadata row.
* 1st round-trip → UPDATE
* 2nd round-trip → SELECT fresh row (same as read_async)
"""
if processing_status is None and error_message is None:
if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None:
raise ValueError("Nothing to update")
values: dict[str, object] = {"updated_at": datetime.utcnow()}
@@ -125,6 +127,10 @@ class FileManager:
values["processing_status"] = processing_status
if error_message is not None:
values["error_message"] = error_message
if total_chunks is not None:
values["total_chunks"] = total_chunks
if chunks_embedded is not None:
values["chunks_embedded"] = chunks_embedded
async with db_registry.async_session() as session:
# Fast in-place update no ORM hydration

View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import List
from letta.log import get_logger
from letta.schemas.passage import Passage
from letta.schemas.user import User
logger = get_logger(__name__)
class BaseEmbedder(ABC):
    """Abstract base class for embedding generation"""

    @abstractmethod
    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
        """Generate embeddings for chunks with batching and concurrent processing

        Args:
            file_id: Id of the file the chunks belong to.
            source_id: Id of the source the file belongs to.
            chunks: Text chunks to turn into passages.
            actor: User on whose behalf the passages are generated.

        Returns:
            The passages produced for ``chunks``.
        """

View File

@@ -9,12 +9,13 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
from letta.settings import model_settings
logger = get_logger(__name__)
class OpenAIEmbedder:
class OpenAIEmbedder(BaseEmbedder):
"""OpenAI-based embedding generation"""
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
@@ -24,6 +25,7 @@ class OpenAIEmbedder:
else EmbeddingConfig.default_config(model_name="letta")
)
self.embedding_config = embedding_config or self.default_embedding_config
self.max_concurrent_requests = 20
# TODO: Unify to global OpenAI client
self.client: OpenAIClient = cast(
@@ -34,7 +36,6 @@ class OpenAIEmbedder:
actor=None, # Not necessary
),
)
self.max_concurrent_requests = 20
@trace_method
async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:

View File

@@ -0,0 +1,74 @@
from typing import List
from letta.helpers.pinecone_utils import upsert_file_records_to_pinecone_index
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
try:
    # Probe for the dependency. Without an import in the try body, ImportError
    # can never be raised and the flag is unconditionally True.
    import pinecone  # noqa: F401

    PINECONE_AVAILABLE = True
except ImportError:
    PINECONE_AVAILABLE = False
logger = get_logger(__name__)
class PineconeEmbedder(BaseEmbedder):
    """Pinecone-based embedding generation"""

    def __init__(self):
        """Fail fast with an install hint when the Pinecone package is unavailable."""
        if not PINECONE_AVAILABLE:
            raise ImportError("Pinecone package is not installed. Install it with: pip install pinecone")

        super().__init__()

    @trace_method
    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
        """Upsert the chunks to Pinecone, then return matching Passage objects.

        Args:
            file_id: Id of the file the chunks belong to.
            source_id: Id of the source the file belongs to.
            chunks: Text chunks to upsert.
            actor: User performing the upload; its ``organization_id`` is the
                Pinecone namespace and is stamped on every passage.

        Returns:
            One ``Passage`` per chunk, with ``embedding`` and
            ``embedding_config`` left as None. An empty list when ``chunks``
            is empty.

        Raises:
            Exception: Any upsert failure is logged and re-raised.
        """
        if not chunks:
            return []

        # Records are written to the actor's organization namespace, not to a
        # per-source namespace; report that value so logs and telemetry match
        # what the Pinecone helpers actually use.
        namespace = actor.organization_id

        logger.info(f"Upserting {len(chunks)} chunks to Pinecone using namespace {namespace}")
        log_event(
            "embedder.generation_started",
            {
                "total_chunks": len(chunks),
                "file_id": file_id,
                "source_id": source_id,
            },
        )

        # Upsert records to Pinecone in the organization namespace
        try:
            await upsert_file_records_to_pinecone_index(file_id=file_id, source_id=source_id, chunks=chunks, actor=actor)
            logger.info(f"Successfully kicked off upserting {len(chunks)} records to Pinecone")
            log_event(
                "embedder.upsert_started",
                {"records_upserted": len(chunks), "namespace": namespace, "file_id": file_id},
            )
        except Exception as e:
            logger.error(f"Failed to upsert records to Pinecone: {str(e)}")
            log_event("embedder.upsert_failed", {"error": str(e), "error_type": type(e).__name__})
            raise

        # Create Passage objects (without embeddings since Pinecone handles them)
        passages = [
            Passage(
                text=text,
                file_id=file_id,
                source_id=source_id,
                embedding=None,  # Pinecone handles embeddings internally
                embedding_config=None,
                organization_id=actor.organization_id,
            )
            for text in chunks
        ]

        logger.info(f"Successfully created {len(passages)} passages")
        log_event(
            "embedder.generation_completed",
            {"passages_created": len(passages), "total_chunks_processed": len(chunks), "file_id": file_id, "source_id": source_id},
        )
        return passages

View File

@@ -11,7 +11,7 @@ from letta.server.server import SyncServer
from letta.services.file_manager import FileManager
from letta.services.file_processor.chunker.line_chunker import LineChunker
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
from letta.services.job_manager import JobManager
from letta.services.passage_manager import PassageManager
@@ -27,8 +27,9 @@ class FileProcessor:
self,
file_parser: MistralFileParser,
text_chunker: LlamaIndexChunker,
embedder: OpenAIEmbedder,
embedder: BaseEmbedder,
actor: User,
using_pinecone: bool,
max_file_size: int = 50 * 1024 * 1024, # 50MB default
):
self.file_parser = file_parser
@@ -41,6 +42,7 @@ class FileProcessor:
self.passage_manager = PassageManager()
self.job_manager = JobManager()
self.actor = actor
self.using_pinecone = using_pinecone
# TODO: Factor this function out of SyncServer
@trace_method
@@ -109,7 +111,7 @@ class FileProcessor:
logger.info("Chunking extracted text")
log_event("file_processor.chunking_started", {"filename": filename, "pages_to_process": len(ocr_response.pages)})
all_passages = []
all_chunks = []
for page in ocr_response.pages:
chunks = self.text_chunker.chunk_text(page)
@@ -118,24 +120,17 @@ class FileProcessor:
log_event("file_processor.chunking_failed", {"filename": filename, "page_index": ocr_response.pages.index(page)})
raise ValueError("No chunks created from text")
passages = await self.embedder.generate_embedded_passages(
file_id=file_metadata.id, source_id=source_id, chunks=chunks, actor=self.actor
)
log_event(
"file_processor.page_processed",
{
"filename": filename,
"page_index": ocr_response.pages.index(page),
"chunks_created": len(chunks),
"passages_generated": len(passages),
},
)
all_passages.extend(passages)
all_chunks.extend(self.text_chunker.chunk_text(page))
all_passages = await self.passage_manager.create_many_source_passages_async(
passages=all_passages, file_metadata=file_metadata, actor=self.actor
all_passages = await self.embedder.generate_embedded_passages(
file_id=file_metadata.id, source_id=source_id, chunks=all_chunks, actor=self.actor
)
log_event("file_processor.passages_created", {"filename": filename, "total_passages": len(all_passages)})
if not self.using_pinecone:
all_passages = await self.passage_manager.create_many_source_passages_async(
passages=all_passages, file_metadata=file_metadata, actor=self.actor
)
log_event("file_processor.passages_created", {"filename": filename, "total_passages": len(all_passages)})
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
log_event(
@@ -149,9 +144,14 @@ class FileProcessor:
)
# update job status
await self.file_manager.update_file_status(
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
)
if not self.using_pinecone:
await self.file_manager.update_file_status(
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
)
else:
await self.file_manager.update_file_status(
file_id=file_metadata.id, actor=self.actor, total_chunks=len(all_passages), chunks_embedded=0
)
return all_passages

View File

@@ -115,10 +115,6 @@ class JobManager:
job.completed_at = get_utc_time().replace(tzinfo=None)
if job.callback_url:
await self._dispatch_callback_async(job)
else:
logger.info(f"Job does not contain callback url: {job}")
else:
logger.info(f"Job update is not terminal {job_update}")
# Save the updated job to the database
await job.update_async(db_session=session, actor=actor)

View File

@@ -19,7 +19,6 @@ class SourceManager:
@trace_method
async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
"""Create a new source based on the PydanticSource schema."""
# Try getting the source first by id
db_source = await self.get_source_by_id(source.id, actor=actor)
if db_source:
return db_source

View File

@@ -2,8 +2,9 @@ import asyncio
import re
from typing import Any, Dict, List, Optional
from letta.constants import MAX_FILES_OPEN
from letta.constants import MAX_FILES_OPEN, PINECONE_TEXT_FIELD_NAME
from letta.functions.types import FileOpenRequest
from letta.helpers.pinecone_utils import search_pinecone_index, should_use_pinecone
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -463,14 +464,15 @@ class LettaFileToolExecutor(ToolExecutor):
return "\n".join(formatted_results)
@trace_method
async def semantic_search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
async def semantic_search_files(self, agent_state: AgentState, query: str, limit: int = 5) -> str:
"""
Search for text within attached files using semantic search and return passages with their source filenames.
Uses Pinecone if configured, otherwise falls back to traditional search.
Args:
agent_state: Current agent state
query: Search query for semantic matching
limit: Maximum number of results to return (default: 10)
limit: Maximum number of results to return (default: 5)
Returns:
Formatted string with search results in IDE/terminal style
@@ -485,6 +487,110 @@ class LettaFileToolExecutor(ToolExecutor):
self.logger.info(f"Semantic search started for agent {agent_state.id} with query '{query}' (limit: {limit})")
# Check if Pinecone is enabled and use it if available
if should_use_pinecone():
return await self._search_files_pinecone(agent_state, query, limit)
else:
return await self._search_files_traditional(agent_state, query, limit)
async def _search_files_pinecone(self, agent_state: AgentState, query: str, limit: int) -> str:
    """Search files using Pinecone vector database.

    Args:
        agent_state: Agent whose attached sources and files are searched.
        query: Search query for semantic matching.
        limit: Maximum number of hits to request and to report.

    Returns:
        A formatted, terminal-style listing of matching passages grouped by
        file name, or a short message when there is nothing to search or
        nothing matched.

    Raises:
        Exception: Any failure from the Pinecone search is logged and re-raised.
    """
    # Extract unique source_ids
    # TODO: Inefficient
    attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
    source_ids = [source.id for source in attached_sources]
    if not source_ids:
        return "No valid source IDs found for attached files"

    # Get all attached files for this agent
    file_agents = await self.files_agents_manager.list_files_for_agent(agent_id=agent_state.id, actor=self.actor)
    if not file_agents:
        return "No files are currently attached to search"

    # Build the file_id -> file_name lookup once instead of scanning file_agents
    # for every hit. setdefault keeps the first entry per file_id, which is the
    # entry a first-match linear scan would have picked.
    file_names_by_id = {}
    for fa in file_agents:
        file_names_by_id.setdefault(fa.file_id, fa.file_name)

    total_hits = 0
    files_with_matches = {}

    try:
        source_filter = {"source_id": {"$in": source_ids}}
        search_results = await search_pinecone_index(query, limit, source_filter, self.actor)

        # Process search results
        if "result" in search_results and "hits" in search_results["result"]:
            for hit in search_results["result"]["hits"]:
                if total_hits >= limit:
                    break
                total_hits += 1

                # Extract hit information
                hit_id = hit.get("_id", "unknown")
                score = hit.get("_score", 0.0)
                fields = hit.get("fields", {})
                text = fields.get(PINECONE_TEXT_FIELD_NAME, "")
                file_id = fields.get("file_id", "")

                # Hits from files that are not attached to this agent fall back to a placeholder name
                file_name = file_names_by_id.get(file_id, "Unknown File")

                # Group by file name
                files_with_matches.setdefault(file_name, []).append({"text": text, "score": score, "hit_id": hit_id})
    except Exception as e:
        self.logger.error(f"Pinecone search failed: {str(e)}")
        raise

    if not files_with_matches:
        return f"No semantic matches found in Pinecone for query: '{query}'"

    # Format results
    results = []
    passage_num = 0
    for file_name, matches in files_with_matches.items():
        for match in matches:
            passage_num += 1

            # Format each passage with terminal-style header
            score_display = f"(score: {match['score']:.3f})"
            passage_header = f"\n=== {file_name} (passage #{passage_num}) {score_display} ==="

            # Format the passage text, limited to the first 20 lines per passage
            lines = match["text"].strip().splitlines()
            formatted_lines = [f" {line}" for line in lines[:20]]
            if len(lines) > 20:
                formatted_lines.append(f" ... [truncated {len(lines) - 20} more lines]")

            passage_content = "\n".join(formatted_lines)
            results.append(f"{passage_header}\n{passage_content}")

    # Mark access for files that had matches; the placeholder name is not a real file
    matched_file_names = [name for name in files_with_matches if name != "Unknown File"]
    if matched_file_names:
        await self.files_agents_manager.mark_access_bulk(agent_id=agent_state.id, file_names=matched_file_names, actor=self.actor)

    # Create summary header
    file_count = len(files_with_matches)
    summary = f"Found {total_hits} Pinecone matches in {file_count} file{'s' if file_count != 1 else ''} for query: '{query}'"

    # Combine all results
    formatted_results = [summary, "=" * len(summary)] + results

    self.logger.info(f"Pinecone search completed: {total_hits} matches across {file_count} files")
    return "\n".join(formatted_results)
async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str:
"""Traditional search using existing passage manager."""
# Get semantic search results
passages = await self.agent_manager.list_source_passages_async(
actor=self.actor,

View File

@@ -253,6 +253,13 @@ class Settings(BaseSettings):
llm_request_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM requests in seconds")
llm_stream_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM streaming requests in seconds")
# For embeddings
enable_pinecone: bool = False
pinecone_api_key: Optional[str] = None
pinecone_source_index: Optional[str] = "sources"
pinecone_agent_index: Optional[str] = "recall"
upsert_pinecone_indices: bool = False
@property
def letta_pg_uri(self) -> str:
if self.pg_uri:

594
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -98,6 +98,7 @@ redis = {version = "^6.2.0", optional = true}
structlog = "^25.4.0"
certifi = "^2025.6.15"
aioboto3 = {version = "^14.3.0", optional = true}
pinecone = {extras = ["asyncio"], version = "^7.3.0"}
[tool.poetry.extras]
@@ -119,6 +120,7 @@ black = "^24.4.2"
ipykernel = "^6.29.5"
ipdb = "^0.13.13"
pytest-mock = "^3.14.0"
pinecone = "^7.3.0"
[tool.poetry.group."dev,tests".dependencies]

View File

@@ -28,6 +28,24 @@ def disable_e2b_api_key() -> Generator[None, None, None]:
tool_settings.e2b_api_key = original_api_key
@pytest.fixture
def disable_pinecone() -> Generator[None, None, None]:
    """
    Turn Pinecone off for the duration of a test.

    Sets `settings.enable_pinecone` to False and `settings.pinecone_api_key`
    to None, yields to the test, then puts both original values back.
    """
    from letta.settings import settings

    saved_enable_flag, saved_api_key = settings.enable_pinecone, settings.pinecone_api_key

    settings.enable_pinecone = False
    settings.pinecone_api_key = None

    yield

    settings.enable_pinecone = saved_enable_flag
    settings.pinecone_api_key = saved_api_key
@pytest.fixture
def check_e2b_key_is_set():
from letta.settings import tool_settings

View File

@@ -3320,7 +3320,7 @@ async def test_update_tool_pip_requirements(server: SyncServer, print_tool, defa
# Add pip requirements to existing tool
pip_reqs = [
PipRequirement(name="pandas", version="1.5.0"),
PipRequirement(name="matplotlib"),
PipRequirement(name="sumy"),
]
tool_update = ToolUpdate(pip_requirements=pip_reqs)
@@ -3334,7 +3334,7 @@ async def test_update_tool_pip_requirements(server: SyncServer, print_tool, defa
assert len(updated_tool.pip_requirements) == 2
assert updated_tool.pip_requirements[0].name == "pandas"
assert updated_tool.pip_requirements[0].version == "1.5.0"
assert updated_tool.pip_requirements[1].name == "matplotlib"
assert updated_tool.pip_requirements[1].name == "sumy"
assert updated_tool.pip_requirements[1].version is None
@@ -5218,6 +5218,41 @@ async def test_update_file_status_error_only(server, default_user, default_sourc
assert updated.processing_status == FileProcessingStatus.PENDING # default from creation
@pytest.mark.asyncio
async def test_update_file_status_with_chunks(server, default_user, default_source):
    """Chunk progress counters can be written together with the status, or on their own."""
    file_manager = server.file_manager

    new_file = await file_manager.create_file(
        file_metadata=PydanticFileMetadata(
            file_name="chunks_test.txt",
            file_path="/tmp/chunks_test.txt",
            file_type="text/plain",
            file_size=500,
            source_id=default_source.id,
        ),
        actor=default_user,
    )

    # Step 1: set the status and both chunk counters in a single call.
    first_pass = await file_manager.update_file_status(
        file_id=new_file.id,
        actor=default_user,
        processing_status=FileProcessingStatus.EMBEDDING,
        total_chunks=100,
        chunks_embedded=50,
    )
    assert first_pass.processing_status == FileProcessingStatus.EMBEDDING
    assert first_pass.total_chunks == 100
    assert first_pass.chunks_embedded == 50

    # Step 2: pass only chunks_embedded; the fields left out must keep their values.
    second_pass = await file_manager.update_file_status(
        file_id=new_file.id,
        actor=default_user,
        chunks_embedded=100,
    )
    assert second_pass.chunks_embedded == 100
    assert second_pass.total_chunks == 100  # unchanged
    assert second_pass.processing_status == FileProcessingStatus.EMBEDDING  # unchanged
@pytest.mark.asyncio
async def test_upsert_file_content_basic(server: SyncServer, default_user, default_source, async_session):
"""Test creating and updating file content with upsert_file_content()."""

View File

@@ -9,9 +9,10 @@ from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client.types import AgentState
from letta.constants import FILES_TOOLS
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
from letta.orm.enums import ToolType
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
from tests.utils import wait_for_server
# Constants
@@ -49,7 +50,7 @@ def client() -> LettaSDKClient:
yield client
def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str, max_wait: int = 30):
def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str, max_wait: int = 60):
"""Helper function to upload a file and wait for processing to complete"""
with open(file_path, "rb") as f:
file_metadata = client.sources.files.upload(source_id=source_id, file=f)
@@ -70,7 +71,7 @@ def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str,
@pytest.fixture
def agent_state(client: LettaSDKClient):
def agent_state(disable_pinecone, client: LettaSDKClient):
open_file_tool = client.tools.list(name="open_files")[0]
search_files_tool = client.tools.list(name="semantic_search_files")[0]
grep_tool = client.tools.list(name="grep_files")[0]
@@ -93,7 +94,7 @@ def agent_state(client: LettaSDKClient):
# Tests
def test_auto_attach_detach_files_tools(client: LettaSDKClient):
def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient):
"""Test automatic attachment and detachment of file tools when managing agent sources."""
# Create agent with basic configuration
agent = client.agents.create(
@@ -164,6 +165,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
],
)
def test_file_upload_creates_source_blocks_correctly(
disable_pinecone,
client: LettaSDKClient,
agent_state: AgentState,
file_path: str,
@@ -204,7 +206,7 @@ def test_file_upload_creates_source_blocks_correctly(
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
assert len(client.sources.list()) == 1
@@ -240,7 +242,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC
assert not any("test" in b.value for b in blocks)
def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
assert len(client.sources.list()) == 1
@@ -270,7 +272,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a
assert not any("test" in b.value for b in blocks)
def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -377,7 +379,7 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat
print("✓ File successfully opened with different range - content differs as expected")
def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -423,7 +425,7 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -465,7 +467,7 @@ def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: Ag
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -517,7 +519,7 @@ def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state:
assert "513:" in tool_return_message.tool_return
def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: LettaSDKClient):
def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient):
"""Test that creating an agent with source_ids parameter correctly creates source blocks."""
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -560,7 +562,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: Le
assert file_tools == set(FILES_TOOLS)
def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentState):
def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -623,7 +625,7 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta
)
def test_duplicate_file_renaming(client: LettaSDKClient):
def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient):
"""Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)"""
# Create a new source
source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small")
@@ -662,7 +664,7 @@ def test_duplicate_file_renaming(client: LettaSDKClient):
print(f" File {i+1}: original='{file.original_file_name}' → renamed='{file.file_name}'")
def test_open_files_schema_descriptions(client: LettaSDKClient):
def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
"""Test that open_files tool schema contains correct descriptions from docstring"""
# Get the open_files tool
@@ -743,3 +745,132 @@ def test_open_files_schema_descriptions(client: LettaSDKClient):
expected_length_desc = "Optional number of lines to view from offset (inclusive). If not specified, views to end of file."
assert length_prop["description"] == expected_length_desc
assert length_prop["type"] == "integer"
# --- Pinecone Tests ---
def test_pinecone_search_files_tool(client: LettaSDKClient):
    """Exercise semantic_search_files end to end when Pinecone is configured; skip otherwise."""
    from letta.helpers.pinecone_utils import should_use_pinecone

    pinecone_available = should_use_pinecone(verbose=True)
    if not pinecone_available:
        pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")

    print("Testing Pinecone search_files tool functionality")

    # Agent that will be asked to run the search
    pinecone_agent = client.agents.create(
        name="test_pinecone_agent",
        memory_blocks=[CreateBlock(label="human", value="username: testuser")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    # Source holding the searchable file, attached to the agent before upload
    pinecone_source = client.sources.create(name="test_pinecone_source", embedding="openai/text-embedding-3-small")
    client.agents.sources.attach(source_id=pinecone_source.id, agent_id=pinecone_agent.id)
    upload_file_and_wait(client, pinecone_source.id, "tests/data/long_test.txt")

    # Ask the agent to run the semantic search
    prompt = "Use the semantic_search_files tool to search for 'electoral history' in the files."
    reply = client.agents.messages.create(
        agent_id=pinecone_agent.id,
        messages=[MessageCreate(role="user", content=prompt)],
    )

    # Split the reply into tool calls and tool returns in one pass
    calls = []
    returns = []
    for message in reply.messages:
        if message.message_type == "tool_call_message":
            calls.append(message)
        elif message.message_type == "tool_return_message":
            returns.append(message)

    # The search tool must have been invoked
    assert calls, "No tool calls found"
    called_tool_names = {call.tool_call.name for call in calls}
    assert "semantic_search_files" in called_tool_names, "semantic_search_files not called"

    # Every tool invocation must have succeeded
    assert returns, "No tool returns found"
    for tool_return_message in returns:
        assert tool_return_message.status == "success", "Tool call failed"

    # The first tool return is expected to mention the searched topic
    search_results = returns[0].tool_return
    print(search_results)
    lowered_results = search_results.lower()
    mentions_topic = "electoral" in lowered_results or "history" in lowered_results
    assert mentions_topic, f"Search results should contain relevant content: {search_results}"
def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
    """Test that file and source deletion removes records from Pinecone.

    Two files are uploaded to a fresh source. The first is deleted on its own,
    then the whole source is deleted. Before each deletion the test asserts
    that Pinecone actually lists records for the affected files, so a pipeline
    that never indexed anything cannot pass vacuously. After each deletion it
    polls (instead of sleeping a fixed time) until the listing is empty or a
    timeout expires.

    Skipped when Pinecone is not configured.
    """
    import asyncio

    from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone

    if not should_use_pinecone():
        pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")

    print("Testing Pinecone file and source deletion lifecycle")

    # Get temp user for Pinecone operations
    user = User(name="temp", organization_id=DEFAULT_ORG_ID)

    def list_records(files):
        """Return every Pinecone record currently listed for the given uploaded files."""
        records = []
        for file_metadata in files:
            records.extend(asyncio.run(list_pinecone_index_for_files(file_metadata.id, user)))
        return records

    def wait_for_records(files, *, present, timeout=15.0, poll_interval=0.5):
        """Poll the listing until it is non-empty (present=True) or empty (present=False).

        Returns the last listing seen, whether or not the expected state was
        reached before the timeout; the caller asserts on it.
        """
        deadline = time.monotonic() + timeout
        while True:
            records = list_records(files)
            if bool(records) == present or time.monotonic() >= deadline:
                return records
            time.sleep(poll_interval)

    # Create source
    source = client.sources.create(name="test_lifecycle_source", embedding="openai/text-embedding-3-small")

    # Upload multiple files and wait for processing
    file_paths = ["tests/data/test.txt", "tests/data/test.md"]
    uploaded_files = [upload_file_and_wait(client, source.id, file_path) for file_path in file_paths]

    file_to_delete, *remaining_files = uploaded_files

    # --- File-level deletion ---
    # The file must have records to begin with, otherwise the check below proves nothing.
    records_before = wait_for_records([file_to_delete], present=True)
    print(f"Found {len(records_before)} records for file before deletion")
    assert records_before, "Expected Pinecone records for the uploaded file before deletion, but found none"

    # Delete the file
    client.sources.files.delete(source_id=source.id, file_id=file_to_delete.id)

    # Verify file records are removed (polling allows time for deletion to propagate)
    records_after = wait_for_records([file_to_delete], present=False)
    print(f"Found {len(records_after)} records for file after deletion")
    assert len(records_after) == 0, f"File records should be removed from Pinecone after deletion, but found {len(records_after)}"

    # --- Source-level deletion ---
    # Only the files that were not deleted above are checked here.
    remaining_records = wait_for_records(remaining_files, present=True)
    records_before = len(remaining_records)
    print(f"Found {records_before} records for remaining files before source deletion")
    assert records_before > 0, "Expected Pinecone records for the remaining files before source deletion, but found none"

    # Delete the entire source
    client.sources.delete(source_id=source.id)

    # Verify all remaining file records are removed
    records_after = wait_for_records(remaining_files, present=False)
    print(f"Found {len(records_after)} records for files after source deletion")
    assert (
        len(records_after) == 0
    ), f"All source records should be removed from Pinecone after source deletion, but found {len(records_after)}"

    print("✓ Pinecone lifecycle verified - namespace is clean after source deletion")