diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 1b801de2..f8139955 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -62,34 +62,7 @@ class PassageManager: @trace_method def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: """Create a new passage in the appropriate table based on whether it has agent_id or source_id.""" - # Common fields for both passage types - data = pydantic_passage.model_dump(to_orm=True) - common_fields = { - "id": data.get("id"), - "text": data["text"], - "embedding": data["embedding"], - "embedding_config": data["embedding_config"], - "organization_id": data["organization_id"], - "metadata_": data.get("metadata", {}), - "is_deleted": data.get("is_deleted", False), - "created_at": data.get("created_at", datetime.now(timezone.utc)), - } - - if "agent_id" in data and data["agent_id"]: - assert not data.get("source_id"), "Passage cannot have both agent_id and source_id" - agent_fields = { - "agent_id": data["agent_id"], - } - passage = AgentPassage(**common_fields, **agent_fields) - elif "source_id" in data and data["source_id"]: - assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id" - source_fields = { - "source_id": data["source_id"], - "file_id": data.get("file_id"), - } - passage = SourcePassage(**common_fields, **source_fields) - else: - raise ValueError("Passage must have either agent_id or source_id") + passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage) with db_registry.session() as session: passage.create(session, actor=actor) @@ -100,6 +73,13 @@ class PassageManager: async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: """Create a new passage in the appropriate table based on whether it has agent_id or source_id.""" # Common fields for both passage types + passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage) + async with db_registry.async_session() as session: + passage = await passage.create_async(session, actor=actor) + return passage.to_pydantic() + + @trace_method + def _preprocess_passage_for_creation(self, pydantic_passage: PydanticPassage) -> "SqlAlchemyBase": data = pydantic_passage.model_dump(to_orm=True) common_fields = { "id": data.get("id"), @@ -128,9 +108,7 @@ class PassageManager: else: raise ValueError("Passage must have either agent_id or source_id") - async with db_registry.async_session() as session: - passage = await passage.create_async(session, actor=actor) - return passage.to_pydantic() + return passage @enforce_types @trace_method @@ -142,7 +120,28 @@ class PassageManager: @trace_method async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: """Create multiple passages.""" - return await asyncio.gather(*[self.create_passage_async(p, actor) for p in passages]) + async with db_registry.async_session() as session: + agent_passages = [] + source_passages = [] + + for p in passages: + model = self._preprocess_passage_for_creation(p) + if isinstance(model, AgentPassage): + agent_passages.append(model) + elif isinstance(model, SourcePassage): + source_passages.append(model) + else: + raise TypeError(f"Unexpected passage type: {type(model)}") + + results = [] + if agent_passages: + agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor) + results.extend(agent_created) + if source_passages: + source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor) + results.extend(source_created) + + return [p.to_pydantic() for p in results] @enforce_types @trace_method @@ -215,53 +214,66 @@ class PassageManager: """Insert passage(s) into archival memory""" embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size + text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size)) - # TODO eventually migrate off of llama-index for embeddings? - # Already causing pain for OpenAI proxy endpoints like LM Studio... - if agent_state.embedding_config.embedding_endpoint_type != "openai": - embed_model = embedding_model(agent_state.embedding_config) - - passages = [] + if not text_chunks: + return [] try: - # breakup string into passages - for text in parse_and_chunk_text(text, embedding_chunk_size): + embeddings = await self._generate_embeddings_concurrent(text_chunks, agent_state.embedding_config) - if agent_state.embedding_config.embedding_endpoint_type != "openai": - embedding = embed_model.get_text_embedding(text) - else: - # TODO should have the settings passed in via the server call - embedding = await get_openai_embedding_async( - text, - agent_state.embedding_config.embedding_model, - agent_state.embedding_config.embedding_endpoint, - ) - - if isinstance(embedding, dict): - try: - embedding = embedding["data"][0]["embedding"] - except (KeyError, IndexError): - # TODO as a fallback, see if we can find any lists in the payload - raise TypeError( - f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}" - ) - passage = await self.create_passage_async( - PydanticPassage( - organization_id=actor.organization_id, - agent_id=agent_id, - text=text, - embedding=embedding, - embedding_config=agent_state.embedding_config, - ), - actor=actor, + passages = [ + PydanticPassage( + organization_id=actor.organization_id, + agent_id=agent_id, + text=chunk_text, + embedding=embedding, + embedding_config=agent_state.embedding_config, ) - passages.append(passage) + for chunk_text, embedding in zip(text_chunks, embeddings) + ] + + passages = await self.create_many_passages_async(passages=passages, actor=actor) return passages except Exception as e: raise e + async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config) -> List[List[float]]: + """Generate embeddings for all text chunks concurrently""" + + if embedding_config.embedding_endpoint_type != "openai": + embed_model = embedding_model(embedding_config) + loop = asyncio.get_event_loop() + + tasks = [loop.run_in_executor(None, embed_model.get_text_embedding, text) for text in text_chunks] + embeddings = await asyncio.gather(*tasks) + else: + tasks = [ + get_openai_embedding_async( + text, + embedding_config.embedding_model, + embedding_config.embedding_endpoint, + ) + for text in text_chunks + ] + embeddings = await asyncio.gather(*tasks) + + processed_embeddings = [] + for embedding in embeddings: + if isinstance(embedding, dict): + try: + processed_embeddings.append(embedding["data"][0]["embedding"]) + except (KeyError, IndexError): + raise TypeError( + f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}" + ) + else: + processed_embeddings.append(embedding) + + return processed_embeddings + @enforce_types @trace_method def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]: diff --git a/performance_tests/test_insert_archival_memory.py b/performance_tests/test_insert_archival_memory.py new file mode 100644 index 00000000..4ce29664 --- /dev/null +++ b/performance_tests/test_insert_archival_memory.py @@ -0,0 +1,185 @@ +import asyncio +import logging +import os +import threading +import time +import uuid +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pytest +from dotenv import load_dotenv +from faker import Faker +from letta_client import AsyncLetta +from tqdm import tqdm + +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig + +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("httpcore").setLevel(logging.WARNING) + + +# --- Server Management --- # + + +def _run_server(): + """Starts the Letta server in a background thread.""" + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + +@pytest.fixture(scope="session") +def server_url(): + """Ensures a server is running and returns its base URL.""" + url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + time.sleep(2) # Allow server startup time + + return url + + +# --- Client Setup --- # + + +@pytest.fixture(scope="session") +def client(server_url): + """Creates a REST client for testing.""" + client = AsyncLetta(base_url=server_url) + yield client + + +# --- Load Test --- # + +NUM_AGENTS = 30 + + +@pytest.mark.asyncio +async def test_insert_archival_memories_concurrent(client): + fake = Faker() + + # 1) Create agents + agent_ids = [] + for i in tqdm(range(NUM_AGENTS), desc="Creating agents"): + agent = await client.agents.create( + name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}", + include_base_tools=True, + memory_blocks=[ + {"label": "human", "value": "Name: Matt"}, + {"label": "persona", "value": "Friendly agent"}, + ], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ) + agent_ids.append(agent.id) + + # 2) Measure start and duration of each call + timeline = [] + + async def measure(agent_index, aid): + t0 = time.perf_counter() + await client.agents.passages.create(agent_id=aid, text=fake.paragraph()) + t1 = time.perf_counter() + timeline.append((agent_index, t0, t1 - t0)) + + await asyncio.gather(*(measure(idx, aid) for idx, aid in enumerate(agent_ids))) + + # 3) Convert to arrays + timeline.sort(key=lambda x: x[0]) + indices = np.array([t[0] for t in timeline]) + starts = np.array([t[1] for t in timeline]) + durs = np.array([t[2] for t in timeline]) + start_offset = starts - starts.min() + + print(f"Latency stats (s): min={durs.min():.3f}, mean={durs.mean():.3f}, " f"max={durs.max():.3f}, std={durs.std():.3f}") + + # 4) Generate improved plots + # Helper: concurrency over time + events = np.concatenate([np.column_stack([starts, np.ones_like(starts)]), np.column_stack([starts + durs, -np.ones_like(durs)])]) + events = events[events[:, 0].argsort()] + concurrency_t = np.cumsum(events[:, 1]) + concurrency_x = events[:, 0] - starts.min() + + # Helper: latency CDF + durs_sorted = np.sort(durs) + cdf_y = np.arange(1, len(durs_sorted) + 1) / len(durs_sorted) + + # Plot all 6 subplots + fig, axes = plt.subplots(2, 3, figsize=(15, 8)) + axs = axes.ravel() + + # 1) Kickoff timeline + axs[0].scatter(indices, start_offset, s=15) + axs[0].set_title("Kick-off timeline") + axs[0].set_xlabel("Call index") + axs[0].set_ylabel("Start offset (s)") + + # 2) Per-call latency + axs[1].plot(indices, durs, marker="o", linestyle="") + axs[1].set_title("Per-call latency") + axs[1].set_xlabel("Call index") + axs[1].set_ylabel("Duration (s)") + + # 3) Latency distribution (histogram) + axs[2].hist(durs, bins="auto") + axs[2].set_title("Latency distribution") + axs[2].set_xlabel("Duration (s)") + axs[2].set_ylabel("Count") + + # 4) Empirical CDF + axs[3].step(durs_sorted, cdf_y, where="post") + axs[3].set_title("Latency CDF") + axs[3].set_xlabel("Duration (s)") + axs[3].set_ylabel("Fraction ≤ x") + + # 5) Concurrency over time + axs[4].step(concurrency_x, concurrency_t, where="post") + axs[4].set_title("Concurrency vs. time") + axs[4].set_xlabel("Time since first start (s)") + axs[4].set_ylabel("# in-flight") + + # 6) Summary stats + axs[5].axis("off") + summary_text = ( + f"n = {len(durs)}\n" + f"min = {durs.min():.3f} s\n" + f"p50 = {np.percentile(durs, 50):.3f} s\n" + f"mean = {durs.mean():.3f} s\n" + f"p95 = {np.percentile(durs, 95):.3f} s\n" + f"max = {durs.max():.3f} s\n" + f"stdev = {durs.std():.3f} s" + ) + axs[5].text(0.02, 0.98, summary_text, va="top", ha="left", fontsize=11, family="monospace", transform=axs[5].transAxes) + + plt.tight_layout() + plt.savefig("latency_diagnostics.png", dpi=150) + print("Saved latency_diagnostics.png") + + +@pytest.mark.asyncio +async def test_insert_large_archival_memory(client): + # 1) Create 30 agents + agent = await client.agents.create( + include_base_tools=True, + memory_blocks=[ + {"label": "human", "value": "Name: Matt"}, + {"label": "persona", "value": "Friendly agent"}, + ], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ) + + file_path = Path(__file__).parent / "data" / "paper1.txt" + text = file_path.read_text() + + t0 = time.perf_counter() + await client.agents.passages.create(agent_id=agent.id, text=text) + t1 = time.perf_counter() + + print(f"Total time: {t1-t0}")