From f4740b138812007ecd247cc49f5192b1bd57fcdc Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 12 Aug 2025 15:11:09 -0700 Subject: [PATCH] chore: remove legacy embeddings (#3846) --- letta/data_sources/connectors.py | 127 +++++---- letta/embeddings.py | 242 +----------------- letta/functions/function_sets/base.py | 8 +- letta/schemas/providers/ollama.py | 2 +- letta/server/server.py | 33 +-- letta/services/agent_manager.py | 24 +- .../services/helpers/agent_manager_helper.py | 38 ++- letta/services/passage_manager.py | 137 ++-------- .../tool_executor/core_tool_executor.py | 2 +- tests/helpers/endpoints_helper.py | 15 +- tests/test_base_functions.py | 41 +-- tests/test_sdk_client.py | 22 ++ tests/test_server.py | 50 ---- tests/test_sources.py | 2 +- 14 files changed, 210 insertions(+), 533 deletions(-) diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index 3248e83c..cfafe2a2 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -4,7 +4,6 @@ import typer from letta.constants import EMBEDDING_BATCH_SIZE from letta.data_sources.connectors_helper import assert_all_files_exist_locally, extract_metadata_from_files, get_filenames_in_dir -from letta.embeddings import embedding_model from letta.schemas.file import FileMetadata from letta.schemas.passage import Passage from letta.schemas.source import Source @@ -40,61 +39,29 @@ class DataConnector: async def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, file_manager: FileManager, actor: "User"): from letta.llm_api.llm_client import LLMClient - from letta.schemas.embedding_config import EmbeddingConfig """Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id.""" embedding_config = source.embedding_config # insert passages/file - texts = [] embedding_to_document_name = {} passage_count = 0 file_count = 0 - async def generate_embeddings(texts: List[str], embedding_config: EmbeddingConfig) -> List[Passage]: - passages = [] - if embedding_config.embedding_endpoint_type == "openai": - texts.append(passage_text) - - client = LLMClient.create( - provider_type=embedding_config.embedding_endpoint_type, - actor=actor, - ) - embeddings = await client.request_embeddings(texts, embedding_config) - - else: - embed_model = embedding_model(embedding_config) - embeddings = [embed_model.get_text_embedding(text) for text in texts] - - # collate passage and embedding - for text, embedding in zip(texts, embeddings): - passage = Passage( - text=text, - file_id=file_metadata.id, - source_id=source.id, - metadata=passage_metadata, - organization_id=source.organization_id, - embedding_config=source.embedding_config, - embedding=embedding, - ) - hashable_embedding = tuple(passage.embedding) - file_name = file_metadata.file_name - if hashable_embedding in embedding_to_document_name: - typer.secho( - f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.", - fg=typer.colors.YELLOW, - ) - continue - - passages.append(passage) - embedding_to_document_name[hashable_embedding] = file_name - return passages + # Use the new LLMClient for all embedding requests + client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, + actor=actor, + ) for file_metadata in connector.find_files(source): file_count += 1 await file_manager.create_file(file_metadata, actor) - # generate passages + # 
generate passages for this file + texts = [] + metadatas = [] + for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size): # for some reason, llama index parsers sometimes return empty strings if len(passage_text) == 0: @@ -104,24 +71,74 @@ async def load_data(connector: DataConnector, source: Source, passage_manager: P ) continue - # get embedding texts.append(passage_text) - if len(texts) >= EMBEDDING_BATCH_SIZE: - passages = await generate_embeddings(texts, embedding_config) - texts = [] - else: - continue + metadatas.append(passage_metadata) + + if len(texts) >= EMBEDDING_BATCH_SIZE: + # Process the batch + embeddings = await client.request_embeddings(texts, embedding_config) + passages = [] + + for text, embedding, passage_metadata in zip(texts, embeddings, metadatas): + passage = Passage( + text=text, + file_id=file_metadata.id, + source_id=source.id, + metadata=passage_metadata, + organization_id=source.organization_id, + embedding_config=source.embedding_config, + embedding=embedding, + ) + hashable_embedding = tuple(passage.embedding) + file_name = file_metadata.file_name + if hashable_embedding in embedding_to_document_name: + typer.secho( + f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.", + fg=typer.colors.YELLOW, + ) + continue + + passages.append(passage) + embedding_to_document_name[hashable_embedding] = file_name + + # insert passages into passage store + await passage_manager.create_many_passages_async(passages, actor) + passage_count += len(passages) + + # Reset for next batch + texts = [] + metadatas = [] + + # Process final remaining texts for this file + if len(texts) > 0: + embeddings = await client.request_embeddings(texts, embedding_config) + passages = [] + + for text, embedding, passage_metadata in zip(texts, embeddings, metadatas): + passage = Passage( + text=text, + file_id=file_metadata.id, + source_id=source.id, + metadata=passage_metadata, + organization_id=source.organization_id, + embedding_config=source.embedding_config, + embedding=embedding, + ) + hashable_embedding = tuple(passage.embedding) + file_name = file_metadata.file_name + if hashable_embedding in embedding_to_document_name: + typer.secho( + f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.", + fg=typer.colors.YELLOW, + ) + continue + + passages.append(passage) + embedding_to_document_name[hashable_embedding] = file_name - # insert passages into passage store await passage_manager.create_many_passages_async(passages, actor) passage_count += len(passages) - # final remaining - if len(texts) > 0: - passages = await generate_embeddings(texts, embedding_config) - await passage_manager.create_many_passages_async(passages, actor) - passage_count += len(passages) - return passage_count, file_count diff --git a/letta/embeddings.py b/letta/embeddings.py index f302c6fb..a07e16c3 100644 --- a/letta/embeddings.py +++ b/letta/embeddings.py @@ -1,13 +1,9 @@ -import uuid -from typing import Any, List, Optional +from typing import List -import numpy as np import tiktoken -from openai import OpenAI -from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP, MAX_EMBEDDING_DIM -from letta.schemas.embedding_config import EmbeddingConfig -from letta.utils import is_valid_url, printd +from 
letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP +from letta.utils import printd def parse_and_chunk_text(text: str, chunk_size: int) -> List[str]: @@ -55,235 +51,3 @@ def check_and_split_text(text: str, embedding_model: str) -> List[str]: text = truncate_text(formatted_text, max_length, encoding) return [text] - - -class EmbeddingEndpoint: - """Implementation for OpenAI compatible endpoint""" - - # """ Based off llama index https://github.com/run-llama/llama_index/blob/a98bdb8ecee513dc2e880f56674e7fd157d1dc3a/llama_index/embeddings/text_embeddings_inference.py """ - - # _user: str = PrivateAttr() - # _timeout: float = PrivateAttr() - # _base_url: str = PrivateAttr() - - def __init__( - self, - model: str, - base_url: str, - user: str, - timeout: float = 60.0, - **kwargs: Any, - ): - if not is_valid_url(base_url): - raise ValueError( - f"Embeddings endpoint was provided an invalid URL (set to: '{base_url}'). Make sure embedding_endpoint is set correctly in your Letta config." - ) - # TODO: find a neater solution - re-mapping for letta endpoint - if model == "letta-free": - model = "BAAI/bge-large-en-v1.5" - self.model_name = model - self._user = user - self._base_url = base_url - self._timeout = timeout - - def _call_api(self, text: str) -> List[float]: - if not is_valid_url(self._base_url): - raise ValueError( - f"Embeddings endpoint does not have a valid URL (set to: '{self._base_url}'). Make sure embedding_endpoint is set correctly in your Letta config." - ) - import httpx - - headers = {"Content-Type": "application/json"} - json_data = {"input": text, "model": self.model_name, "user": self._user} - - with httpx.Client() as client: - response = client.post( - f"{self._base_url}/embeddings", - headers=headers, - json=json_data, - timeout=self._timeout, - ) - - response_json = response.json() - - if isinstance(response_json, list): - # embedding directly in response - embedding = response_json - elif isinstance(response_json, dict): - # TEI embedding packaged inside openai-style response - try: - embedding = response_json["data"][0]["embedding"] - except (KeyError, IndexError): - raise TypeError(f"Got back an unexpected payload from text embedding function, response=\n{response_json}") - else: - # unknown response, can't parse - raise TypeError(f"Got back an unexpected payload from text embedding function, response=\n{response_json}") - - return embedding - - def get_text_embedding(self, text: str) -> List[float]: - return self._call_api(text) - - -class AzureOpenAIEmbedding: - def __init__(self, api_endpoint: str, api_key: str, api_version: str, model: str): - from openai import AzureOpenAI - - self.client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_endpoint) - self.model = model - - def get_text_embedding(self, text: str): - embeddings = self.client.embeddings.create(input=[text], model=self.model).data[0].embedding - return embeddings - - -class OllamaEmbeddings: - - # Format: - # curl http://localhost:11434/api/embeddings -d '{ - # "model": "mxbai-embed-large", - # "prompt": "Llamas are members of the camelid family" - # }' - - def __init__(self, model: str, base_url: str, ollama_additional_kwargs: dict): - self.model = model - self.base_url = base_url - self.ollama_additional_kwargs = ollama_additional_kwargs - - def get_text_embedding(self, text: str): - import httpx - - headers = {"Content-Type": "application/json"} - json_data = {"model": self.model, "prompt": text} - json_data.update(self.ollama_additional_kwargs) - 
- with httpx.Client() as client: - response = client.post( - f"{self.base_url}/api/embeddings", - headers=headers, - json=json_data, - ) - - response_json = response.json() - return response_json["embedding"] - - -class GoogleEmbeddings: - def __init__(self, api_key: str, model: str, base_url: str): - self.api_key = api_key - self.model = model - self.base_url = base_url # Expected to be "https://generativelanguage.googleapis.com" - - def get_text_embedding(self, text: str): - import httpx - - headers = {"Content-Type": "application/json"} - # Build the URL based on the provided base_url, model, and API key. - url = f"{self.base_url}/v1beta/models/{self.model}:embedContent?key={self.api_key}" - payload = {"model": self.model, "content": {"parts": [{"text": text}]}} - with httpx.Client() as client: - response = client.post(url, headers=headers, json=payload) - # Raise an error for non-success HTTP status codes. - response.raise_for_status() - response_json = response.json() - return response_json["embedding"]["values"] - - -class GoogleVertexEmbeddings: - def __init__(self, model: str, project_id: str, region: str): - from google import genai - - self.client = genai.Client(vertexai=True, project=project_id, location=region, http_options={"api_version": "v1"}) - self.model = model - - def get_text_embedding(self, text: str): - response = self.client.generate_embeddings(content=text, model=self.model) - return response.embeddings[0].embedding - - -class OpenAIEmbeddings: - def __init__(self, api_key: str, model: str, base_url: str): - if base_url: - self.client = OpenAI(api_key=api_key, base_url=base_url) - else: - self.client = OpenAI(api_key=api_key) - self.model = model - - def get_text_embedding(self, text: str): - response = self.client.embeddings.create(input=text, model=self.model) - - return response.data[0].embedding - - -def query_embedding(embedding_model, query_text: str): - """Generate padded embedding for querying database""" - query_vec = embedding_model.get_text_embedding(query_text) - query_vec = np.array(query_vec) - query_vec = np.pad(query_vec, (0, MAX_EMBEDDING_DIM - query_vec.shape[0]), mode="constant").tolist() - return query_vec - - -def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None): - """Return LlamaIndex embedding model to use for embeddings""" - - endpoint_type = config.embedding_endpoint_type - - # TODO: refactor to pass in settings from server - from letta.settings import model_settings - - if endpoint_type == "openai": - return OpenAIEmbeddings( - api_key=model_settings.openai_api_key, - model=config.embedding_model, - base_url=config.embedding_endpoint or model_settings.openai_api_base, - ) - - elif endpoint_type == "azure": - assert all( - [ - model_settings.azure_api_key is not None, - model_settings.azure_base_url is not None, - model_settings.azure_api_version is not None, - ] - ) - return AzureOpenAIEmbedding( - api_endpoint=model_settings.azure_base_url, - api_key=model_settings.azure_api_key, - api_version=model_settings.azure_api_version, - model=config.embedding_model, - ) - - elif endpoint_type == "hugging-face": - return EmbeddingEndpoint( - model=config.embedding_model, - base_url=config.embedding_endpoint, - user=user_id, - ) - elif endpoint_type == "ollama": - - model = OllamaEmbeddings( - model=config.embedding_model, - base_url=config.embedding_endpoint, - ollama_additional_kwargs={}, - ) - return model - - elif endpoint_type == "google_ai": - assert all([model_settings.gemini_api_key is not None, 
model_settings.gemini_base_url is not None]) - model = GoogleEmbeddings( - model=config.embedding_model, - api_key=model_settings.gemini_api_key, - base_url=model_settings.gemini_base_url, - ) - return model - - elif endpoint_type == "google_vertex": - model = GoogleVertexEmbeddings( - model=config.embedding_model, - api_key=model_settings.gemini_api_key, - base_url=model_settings.gemini_base_url, - ) - return model - - else: - raise ValueError(f"Unknown endpoint type {endpoint_type}") diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index ee8df51b..d0e2e94c 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -63,7 +63,7 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O return results_str -def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: +async def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: """ Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. @@ -73,7 +73,7 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: Returns: Optional[str]: None is always returned as this function does not produce a response. """ - self.passage_manager.insert_passage( + await self.passage_manager.insert_passage( agent_state=self.agent_state, text=content, actor=self.user, @@ -82,7 +82,7 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: return None -def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]: +async def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]: """ Search archival memory using semantic (embedding-based) search. 
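Both archival tools above become coroutines in this patch, so any direct caller has to await them. A minimal sketch of the new calling convention (illustrative code, not part of the diff; only the function names, parameters, and the (results, next_page) return shape are taken from this file and the updated tests further below):

    from letta.functions.function_sets import base as base_functions

    async def exercise_archival_memory(agent) -> None:
        # archival_memory_insert now awaits PassageManager.insert_passage internally,
        # so the tool itself must be awaited.
        await base_functions.archival_memory_insert(agent, "The cat sleeps on the mat")
        # archival_memory_search is likewise async and returns (results, next_page).
        results, next_page = await base_functions.archival_memory_search(agent, "cat", page=0)
        assert results is not None

    # Usage, assuming `agent` is a fully constructed letta Agent:
    #     import asyncio
    #     asyncio.run(exercise_archival_memory(agent))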
@@ -107,7 +107,7 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s try: # Get results using passage manager - all_results = self.agent_manager.list_passages( + all_results = await self.agent_manager.list_passages_async( actor=self.user, agent_id=self.agent_state.id, query_text=query, diff --git a/letta/schemas/providers/ollama.py b/letta/schemas/providers/ollama.py index df7d8346..d34d86d7 100644 --- a/letta/schemas/providers/ollama.py +++ b/letta/schemas/providers/ollama.py @@ -74,7 +74,7 @@ class OllamaProvider(OpenAIProvider): configs = [] for model in response_json["models"]: - embedding_dim = await self._get_model_embedding_dim_async(model["name"]) + embedding_dim = await self._get_model_embedding_dim(model["name"]) if not embedding_dim: print(f"Ollama model {model['name']} has no embedding dimension, using default 1024") # continue diff --git a/letta/server/server.py b/letta/server/server.py index 468f1cf7..8b396594 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1099,33 +1099,6 @@ class SyncServer(Server): def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary: return RecallMemorySummary(size=self.message_manager.size(actor=actor, agent_id=agent_id)) - def get_agent_archival( - self, - user_id: str, - agent_id: str, - after: Optional[str] = None, - before: Optional[str] = None, - limit: Optional[int] = 100, - order_by: Optional[str] = "created_at", - reverse: Optional[bool] = False, - query_text: Optional[str] = None, - ascending: Optional[bool] = True, - ) -> List[Passage]: - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) - - # iterate over records - records = self.agent_manager.list_passages( - actor=actor, - agent_id=agent_id, - after=after, - query_text=query_text, - before=before, - ascending=ascending, - limit=limit, - ) - return records - async def get_agent_archival_async( self, agent_id: str, @@ -1153,7 +1126,7 @@ class SyncServer(Server): agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) # Insert passages into the archive - passages = await self.passage_manager.insert_passage_async(agent_state=agent_state, text=memory_contents, actor=actor) + passages = await self.passage_manager.insert_passage(agent_state=agent_state, text=memory_contents, actor=actor) return passages @@ -1471,10 +1444,6 @@ class SyncServer(Server): passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor) return passage_count, document_count - def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: - # TODO: move this query into PassageManager - return self.agent_manager.list_passages(actor=self.user_manager.get_user_or_default(user_id=user_id), source_id=source_id) - def list_all_sources(self, actor: User) -> List[Source]: # TODO: legacy: remove """List all sources (w/ extra metadata) belonging to a user""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 5751ac40..65e03d37 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -2354,7 +2354,7 @@ class AgentManager: @enforce_types @trace_method - def list_passages( + async def list_passages( self, actor: PydanticUser, agent_id: Optional[str] = None, @@ -2372,8 +2372,8 @@ class AgentManager: agent_only: bool = False, ) -> 
List[PydanticPassage]: """Lists all passages attached to an agent.""" - with db_registry.session() as session: - main_query = build_passage_query( + async with db_registry.async_session() as session: + main_query = await build_passage_query( actor=actor, agent_id=agent_id, file_id=file_id, @@ -2394,7 +2394,7 @@ class AgentManager: main_query = main_query.limit(limit) # Execute query - result = session.execute(main_query) + result = await session.execute(main_query) passages = [] for row in result: @@ -2437,7 +2437,7 @@ class AgentManager: ) -> List[PydanticPassage]: """Lists all passages attached to an agent.""" async with db_registry.async_session() as session: - main_query = build_passage_query( + main_query = await build_passage_query( actor=actor, agent_id=agent_id, file_id=file_id, @@ -2500,7 +2500,7 @@ class AgentManager: ) -> List[PydanticPassage]: """Lists all passages attached to an agent.""" async with db_registry.async_session() as session: - main_query = build_source_passage_query( + main_query = await build_source_passage_query( actor=actor, agent_id=agent_id, file_id=file_id, @@ -2546,7 +2546,7 @@ class AgentManager: ) -> List[PydanticPassage]: """Lists all passages attached to an agent.""" async with db_registry.async_session() as session: - main_query = build_agent_passage_query( + main_query = await build_agent_passage_query( actor=actor, agent_id=agent_id, query_text=query_text, @@ -2574,7 +2574,7 @@ class AgentManager: @enforce_types @trace_method - def passage_size( + async def passage_size( self, actor: PydanticUser, agent_id: Optional[str] = None, @@ -2591,8 +2591,8 @@ class AgentManager: agent_only: bool = False, ) -> int: """Returns the count of passages matching the given criteria.""" - with db_registry.session() as session: - main_query = build_passage_query( + async with db_registry.async_session() as session: + main_query = await build_passage_query( actor=actor, agent_id=agent_id, file_id=file_id, @@ -2610,7 +2610,7 @@ class AgentManager: # Convert to count query count_query = select(func.count()).select_from(main_query.subquery()) - return session.scalar(count_query) or 0 + return (await session.scalar(count_query)) or 0 @enforce_types async def passage_size_async( @@ -2630,7 +2630,7 @@ class AgentManager: agent_only: bool = False, ) -> int: async with db_registry.async_session() as session: - main_query = build_passage_query( + main_query = await build_passage_query( actor=actor, agent_id=agent_id, file_id=file_id, diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index a97509f4..1319175b 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -20,9 +20,9 @@ from letta.constants import ( MULTI_AGENT_TOOLS, STRUCTURED_OUTPUT_MODELS, ) -from letta.embeddings import embedding_model from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import format_datetime, get_local_time, get_local_time_fast +from letta.llm_api.llm_client import LLMClient from letta.orm.agent import Agent as AgentModel from letta.orm.agents_tags import AgentsTags from letta.orm.archives_agents import ArchivesAgents @@ -939,7 +939,7 @@ def _apply_relationship_filters(query, include_relationships: Optional[List[str] return query -def build_passage_query( +async def build_passage_query( actor: User, agent_id: Optional[str] = None, file_id: Optional[str] = None, @@ -963,8 +963,14 @@ def build_passage_query( if embed_query: assert embedding_config is not 
None, "embedding_config must be specified for vector search" assert query_text is not None, "query_text must be specified for vector search" - embedded_text = embedding_model(embedding_config).get_text_embedding(query_text) - embedded_text = np.array(embedded_text) + + # Use the new LLMClient for embeddings + embedding_client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings([query_text], embedding_config) + embedded_text = np.array(embeddings[0]) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() # Start with base query for source passages @@ -1150,7 +1156,7 @@ def build_passage_query( return main_query -def build_source_passage_query( +async def build_source_passage_query( actor: User, agent_id: Optional[str] = None, file_id: Optional[str] = None, @@ -1171,8 +1177,14 @@ def build_source_passage_query( if embed_query: assert embedding_config is not None, "embedding_config must be specified for vector search" assert query_text is not None, "query_text must be specified for vector search" - embedded_text = embedding_model(embedding_config).get_text_embedding(query_text) - embedded_text = np.array(embedded_text) + + # Use the new LLMClient for embeddings + embedding_client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings([query_text], embedding_config) + embedded_text = np.array(embeddings[0]) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() # Base query for source passages @@ -1248,7 +1260,7 @@ def build_source_passage_query( return query -def build_agent_passage_query( +async def build_agent_passage_query( actor: User, agent_id: str, # Required for agent passages query_text: Optional[str] = None, @@ -1267,8 +1279,14 @@ def build_agent_passage_query( if embed_query: assert embedding_config is not None, "embedding_config must be specified for vector search" assert query_text is not None, "query_text must be specified for vector search" - embedded_text = embedding_model(embedding_config).get_text_embedding(query_text) - embedded_text = np.array(embedded_text) + + # Use the new LLMClient for embeddings + embedding_client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings([query_text], embedding_config) + embedded_text = np.array(embeddings[0]) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() # Base query for agent passages - join through archives_agents diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 895ecef3..39cd07ce 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -1,4 +1,3 @@ -import asyncio from datetime import datetime, timezone from functools import lru_cache from typing import List, Optional @@ -7,8 +6,9 @@ from openai import AsyncOpenAI, OpenAI from sqlalchemy import select from letta.constants import MAX_EMBEDDING_DIM -from letta.embeddings import embedding_model, parse_and_chunk_text +from letta.embeddings import parse_and_chunk_text from letta.helpers.decorators import async_redis_cache +from letta.llm_api.llm_client import LLMClient from letta.orm import ArchivesAgents from letta.orm.errors import 
NoResultFound from letta.orm.passage import ArchivalPassage, SourcePassage @@ -460,7 +460,7 @@ class PassageManager: @enforce_types @trace_method - def insert_passage( + async def insert_passage( self, agent_state: AgentState, text: str, @@ -469,45 +469,32 @@ class PassageManager: """Insert passage(s) into archival memory""" embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size + embedding_client = LLMClient.create( + provider_type=agent_state.embedding_config.embedding_endpoint_type, + actor=actor, + ) - # TODO eventually migrate off of llama-index for embeddings? - # Already causing pain for OpenAI proxy endpoints like LM Studio... - if agent_state.embedding_config.embedding_endpoint_type != "openai": - embed_model = embedding_model(agent_state.embedding_config) + # Get or create the default archive for the agent + archive = await self.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=agent_state.id, agent_name=agent_state.name, actor=actor + ) - passages = [] + text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size)) + + if not text_chunks: + return [] try: - # breakup string into passages - for text in parse_and_chunk_text(text, embedding_chunk_size): - if agent_state.embedding_config.embedding_endpoint_type != "openai": - embedding = embed_model.get_text_embedding(text) - else: - # TODO should have the settings passed in via the server call - embedding = get_openai_embedding( - text, - agent_state.embedding_config.embedding_model, - agent_state.embedding_config.embedding_endpoint, - ) + # Generate embeddings for all chunks using the new async API + embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config) - if isinstance(embedding, dict): - try: - embedding = embedding["data"][0]["embedding"] - except (KeyError, IndexError): - # TODO as a fallback, see if we can find any lists in the payload - raise TypeError( - f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}" - ) - # Get or create the default archive for the agent - archive = self.archive_manager.get_or_create_default_archive_for_agent( - agent_id=agent_state.id, agent_name=agent_state.name, actor=actor - ) - - passage = self.create_agent_passage( + passages = [] + for chunk_text, embedding in zip(text_chunks, embeddings): + passage = await self.create_agent_passage_async( PydanticPassage( organization_id=actor.organization_id, archive_id=archive.id, - text=text, + text=chunk_text, embedding=embedding, embedding_config=agent_state.embedding_config, ), @@ -520,84 +507,16 @@ class PassageManager: except Exception as e: raise e - @enforce_types - @trace_method - async def insert_passage_async( - self, - agent_state: AgentState, - text: str, - actor: PydanticUser, - image_ids: Optional[List[str]] = None, - ) -> List[PydanticPassage]: - """Insert passage(s) into archival memory""" - # Get or create default archive for the agent - archive = await self.archive_manager.get_or_create_default_archive_for_agent_async( - agent_id=agent_state.id, - agent_name=agent_state.name, + async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config, actor: PydanticUser) -> List[List[float]]: + """Generate embeddings for all text chunks concurrently using LLMClient""" + + embedding_client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, actor=actor, ) - archive_id = archive.id - embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size - 
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size)) - - if not text_chunks: - return [] - - try: - embeddings = await self._generate_embeddings_concurrent(text_chunks, agent_state.embedding_config) - - passages = [ - PydanticPassage( - organization_id=actor.organization_id, - archive_id=archive_id, - text=chunk_text, - embedding=embedding, - embedding_config=agent_state.embedding_config, - ) - for chunk_text, embedding in zip(text_chunks, embeddings) - ] - - passages = await self.create_many_archival_passages_async(passages=passages, actor=actor) - - return passages - - except Exception as e: - raise e - - async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config) -> List[List[float]]: - """Generate embeddings for all text chunks concurrently""" - - if embedding_config.embedding_endpoint_type != "openai": - embed_model = embedding_model(embedding_config) - loop = asyncio.get_event_loop() - - tasks = [loop.run_in_executor(None, embed_model.get_text_embedding, text) for text in text_chunks] - embeddings = await asyncio.gather(*tasks) - else: - tasks = [ - get_openai_embedding_async( - text, - embedding_config.embedding_model, - embedding_config.embedding_endpoint, - ) - for text in text_chunks - ] - embeddings = await asyncio.gather(*tasks) - - processed_embeddings = [] - for embedding in embeddings: - if isinstance(embedding, dict): - try: - processed_embeddings.append(embedding["data"][0]["embedding"]) - except (KeyError, IndexError): - raise TypeError( - f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}" - ) - else: - processed_embeddings.append(embedding) - - return processed_embeddings + embeddings = await embedding_client.request_embeddings(text_chunks, embedding_config) + return embeddings @enforce_types @trace_method diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index 40398dcf..35da4c28 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -174,7 +174,7 @@ class LettaCoreToolExecutor(ToolExecutor): Returns: Optional[str]: None is always returned as this function does not produce a response. 
""" - await PassageManager().insert_passage_async( + await PassageManager().insert_passage( agent_state=agent_state, text=content, actor=actor, diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 2fa78a48..72b91b69 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -13,8 +13,8 @@ logger = logging.getLogger(__name__) from letta.config import LettaConfig from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA -from letta.embeddings import embedding_model from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError +from letta.llm_api.llm_client import LLMClient from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.embedding_config import EmbeddingConfig @@ -91,14 +91,21 @@ def setup_agent( # ====================================================================================================================== -def run_embedding_endpoint(filename): +async def run_embedding_endpoint(filename, actor=None): # load JSON file config_data = json.load(open(filename, "r")) print(config_data) embedding_config = EmbeddingConfig(**config_data) - model = embedding_model(embedding_config) + + # Use the new LLMClient for embeddings + client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, + actor=actor, + ) + query_text = "hello" - query_vec = model.get_text_embedding(query_text) + query_vecs = await client.request_embeddings([query_text], embedding_config) + query_vec = query_vecs[0] print("vector dim", len(query_vec)) assert query_vec is not None diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 7ba28ee9..5a88de87 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -1,3 +1,4 @@ +import asyncio import os import threading @@ -96,53 +97,63 @@ def query_in_search_results(search_results, query): return False -def test_archival(agent_obj): +@pytest.mark.asyncio +async def test_archival(agent_obj): """Test archival memory functions comprehensively.""" # Test 1: Basic insertion and retrieval - base_functions.archival_memory_insert(agent_obj, "The cat sleeps on the mat") - base_functions.archival_memory_insert(agent_obj, "The dog plays in the park") - base_functions.archival_memory_insert(agent_obj, "Python is a programming language") + await base_functions.archival_memory_insert(agent_obj, "The cat sleeps on the mat") + await asyncio.sleep(0.1) # Small delay to ensure session cleanup + await base_functions.archival_memory_insert(agent_obj, "The dog plays in the park") + await asyncio.sleep(0.1) + await base_functions.archival_memory_insert(agent_obj, "Python is a programming language") + await asyncio.sleep(0.1) # Test exact text search - results, _ = base_functions.archival_memory_search(agent_obj, "cat") + results, _ = await base_functions.archival_memory_search(agent_obj, "cat") assert query_in_search_results(results, "cat") + await asyncio.sleep(0.1) # Test semantic search (should return animal-related content) - results, _ = base_functions.archival_memory_search(agent_obj, "animal pets") + results, _ = await base_functions.archival_memory_search(agent_obj, "animal pets") assert query_in_search_results(results, "cat") or query_in_search_results(results, "dog") + await asyncio.sleep(0.1) # Test unrelated search (should not return animal content) - results, _ = base_functions.archival_memory_search(agent_obj, 
"programming computers") + results, _ = await base_functions.archival_memory_search(agent_obj, "programming computers") assert query_in_search_results(results, "python") + await asyncio.sleep(0.1) # Test 2: Test pagination # Insert more items to test pagination for i in range(10): - base_functions.archival_memory_insert(agent_obj, f"Test passage number {i}") + await base_functions.archival_memory_insert(agent_obj, f"Test passage number {i}") + await asyncio.sleep(0.05) # Shorter delay for bulk operations # Get first page - page0_results, next_page = base_functions.archival_memory_search(agent_obj, "Test passage", page=0) + page0_results, next_page = await base_functions.archival_memory_search(agent_obj, "Test passage", page=0) + await asyncio.sleep(0.1) # Get second page - page1_results, _ = base_functions.archival_memory_search(agent_obj, "Test passage", page=1, start=next_page) + page1_results, _ = await base_functions.archival_memory_search(agent_obj, "Test passage", page=1, start=next_page) + await asyncio.sleep(0.1) assert page0_results != page1_results assert query_in_search_results(page0_results, "Test passage") assert query_in_search_results(page1_results, "Test passage") # Test 3: Test complex text patterns - base_functions.archival_memory_insert(agent_obj, "Important meeting on 2024-01-15 with John") - base_functions.archival_memory_insert(agent_obj, "Follow-up meeting scheduled for next week") - base_functions.archival_memory_insert(agent_obj, "Project deadline is approaching") + await base_functions.archival_memory_insert(agent_obj, "Important meeting on 2024-01-15 with John") + await base_functions.archival_memory_insert(agent_obj, "Follow-up meeting scheduled for next week") + await base_functions.archival_memory_insert(agent_obj, "Project deadline is approaching") # Search for meeting-related content - results, _ = base_functions.archival_memory_search(agent_obj, "meeting schedule") + results, _ = await base_functions.archival_memory_search(agent_obj, "meeting schedule") assert query_in_search_results(results, "meeting") assert query_in_search_results(results, "2024-01-15") or query_in_search_results(results, "next week") # Test 4: Test error handling # Test invalid page number try: - base_functions.archival_memory_search(agent_obj, "test", page="invalid") + await base_functions.archival_memory_search(agent_obj, "test", page="invalid") assert False, "Should have raised ValueError" except ValueError: pass diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index fddfafa7..250d3c55 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -408,6 +408,28 @@ def test_send_system_message(client: LettaSDKClient, agent: AgentState): assert send_system_message_response, "Sending message failed" +def test_insert_archival_memory(client: LettaSDKClient, agent: AgentState): + passage = client.agents.passages.create( + agent_id=agent.id, + text="This is a test passage", + ) + assert passage, "Inserting archival memory failed" + + # List archival memory and verify content + archival_memory_response = client.agents.passages.list(agent_id=agent.id, limit=1) + archival_memories = [memory.text for memory in archival_memory_response] + assert "This is a test passage" in archival_memories, f"Retrieving archival memory failed: {archival_memories}" + + # Delete the memory + memory_id_to_delete = archival_memory_response[0].id + client.agents.passages.delete(agent_id=agent.id, memory_id=memory_id_to_delete) + + # Verify memory is gone (implicitly checks that the list 
call works)
+    final_passages = client.agents.passages.list(agent_id=agent.id)
+    passage_texts = [p.text for p in final_passages]
+    assert "This is a test passage" not in passage_texts, f"Memory was not deleted: {passage_texts}"
+
+
 def test_function_return_limit(disable_e2b_api_key, client: LettaSDKClient, agent: AgentState):
     """Test to see if the function return limit works"""
 
diff --git a/tests/test_server.py b/tests/test_server.py
index bc093c12..58cfc1c2 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -425,56 +425,6 @@ def test_get_recall_memory(server, org_id, user, agent_id):
         assert message_id in message_ids, f"{message_id} not in {message_ids}"
 
 
-# @pytest.mark.order(6)
-# def test_get_archival_memory(server, user, agent_id):
-#     # test archival memory cursor pagination
-#     actor = user
-#
-#     # List latest 2 passages
-#     passages_1 = server.agent_manager.list_passages(
-#         actor=actor,
-#         agent_id=agent_id,
-#         ascending=False,
-#         limit=2,
-#     )
-#     assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
-#
-#     # List next 3 passages (earliest 3)
-#     cursor1 = passages_1[-1].id
-#     passages_2 = server.agent_manager.list_passages(
-#         actor=actor,
-#         agent_id=agent_id,
-#         ascending=False,
-#         before=cursor1,
-#     )
-#
-#     # List all 5
-#     cursor2 = passages_1[0].created_at
-#     passages_3 = server.agent_manager.list_passages(
-#         actor=actor,
-#         agent_id=agent_id,
-#         ascending=False,
-#         end_date=cursor2,
-#         limit=1000,
-#     )
-#     assert len(passages_2) in [3, 4]  # NOTE: exact size seems non-deterministic, so loosen test
-#     assert len(passages_3) in [4, 5]  # NOTE: exact size seems non-deterministic, so loosen test
-#
-#     latest = passages_1[0]
-#     earliest = passages_2[-1]
-#
-#     # test archival memory
-#     passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
-#     assert len(passage_1) == 1
-#     assert passage_1[0].text == "alpha"
-#     passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=earliest.id, limit=1000, ascending=True)
-#     assert len(passage_2) in [4, 5]  # NOTE: exact size seems non-deterministic, so loosen test
-#     assert all("alpha" not in passage.text for passage in passage_2)
-#     # test safe empty return
-#     passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=latest.id, limit=1000, ascending=True)
-#     assert len(passage_none) == 0
-
-
 @pytest.mark.asyncio
 async def test_get_context_window_overview(server: SyncServer, user, agent_id):
     """Test that the context window overview fetch works"""
diff --git a/tests/test_sources.py b/tests/test_sources.py
index 8005ab35..efe2e73a 100644
--- a/tests/test_sources.py
+++ b/tests/test_sources.py
@@ -525,7 +525,7 @@ def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKCli
     # Check it returned successfully
     tool_returns = [msg for msg in search_files_response.messages if msg.message_type == "tool_return_message"]
     assert len(tool_returns) > 0, "No tool returns found"
-    assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
+    assert all(tr.status == "success" for tr in tool_returns), f"Tool call failed {tool_returns}"
 
 
 def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
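Stepping back from the individual hunks: every provider-specific class deleted from letta/embeddings.py (OpenAIEmbeddings, OllamaEmbeddings, GoogleEmbeddings, and the rest) is replaced by a single batched call path. A minimal sketch of that path, assembled from the two calls that appear verbatim throughout the diff; the wrapper function itself and the untyped actor parameter are illustrative assumptions:

    from typing import List

    from letta.llm_api.llm_client import LLMClient
    from letta.schemas.embedding_config import EmbeddingConfig

    async def embed_batch(texts: List[str], embedding_config: EmbeddingConfig, actor) -> List[List[float]]:
        # Provider dispatch now happens inside LLMClient.create, keyed on the
        # embedding_endpoint_type field, instead of in a per-provider class.
        client = LLMClient.create(
            provider_type=embedding_config.embedding_endpoint_type,
            actor=actor,
        )
        # Batched request: one embedding vector comes back per input text.
        return await client.request_embeddings(texts, embedding_config)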
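On the query side, build_passage_query, build_source_passage_query, and build_agent_passage_query now share the same embed-then-pad step for vector search. A condensed sketch of that step, lifted nearly verbatim from the patched helpers in agent_manager_helper.py; only the standalone wrapper function is hypothetical:

    import numpy as np

    from letta.constants import MAX_EMBEDDING_DIM
    from letta.llm_api.llm_client import LLMClient

    async def embed_query_text(query_text: str, embedding_config, actor) -> list:
        client = LLMClient.create(
            provider_type=embedding_config.embedding_endpoint_type,
            actor=actor,
        )
        embeddings = await client.request_embeddings([query_text], embedding_config)
        embedded_text = np.array(embeddings[0])
        # Zero-pad so vectors from any embedding model fit the fixed-width
        # MAX_EMBEDDING_DIM column that the passage queries compare against.
        return np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()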