feat: Improve insert archival memory latency (#2435)

Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
This commit is contained in:
Matthew Zhou
2025-05-27 10:28:05 -07:00
committed by GitHub
parent c2a7d8c0ce
commit 20c6bf68ff
2 changed files with 265 additions and 68 deletions

View File

@@ -62,34 +62,7 @@ class PassageManager:
@trace_method
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
# Common fields for both passage types
data = pydantic_passage.model_dump(to_orm=True)
common_fields = {
"id": data.get("id"),
"text": data["text"],
"embedding": data["embedding"],
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata", {}),
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.now(timezone.utc)),
}
if "agent_id" in data and data["agent_id"]:
assert not data.get("source_id"), "Passage cannot have both agent_id and source_id"
agent_fields = {
"agent_id": data["agent_id"],
}
passage = AgentPassage(**common_fields, **agent_fields)
elif "source_id" in data and data["source_id"]:
assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id"
source_fields = {
"source_id": data["source_id"],
"file_id": data.get("file_id"),
}
passage = SourcePassage(**common_fields, **source_fields)
else:
raise ValueError("Passage must have either agent_id or source_id")
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
with db_registry.session() as session:
passage.create(session, actor=actor)
@@ -100,6 +73,13 @@ class PassageManager:
async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
# Common fields for both passage types
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
async with db_registry.async_session() as session:
passage = await passage.create_async(session, actor=actor)
return passage.to_pydantic()
@trace_method
def _preprocess_passage_for_creation(self, pydantic_passage: PydanticPassage) -> "SqlAlchemyBase":
data = pydantic_passage.model_dump(to_orm=True)
common_fields = {
"id": data.get("id"),
@@ -128,9 +108,7 @@ class PassageManager:
else:
raise ValueError("Passage must have either agent_id or source_id")
async with db_registry.async_session() as session:
passage = await passage.create_async(session, actor=actor)
return passage.to_pydantic()
return passage
@enforce_types
@trace_method
@@ -142,7 +120,28 @@ class PassageManager:
@trace_method
async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
"""Create multiple passages."""
return await asyncio.gather(*[self.create_passage_async(p, actor) for p in passages])
async with db_registry.async_session() as session:
agent_passages = []
source_passages = []
for p in passages:
model = self._preprocess_passage_for_creation(p)
if isinstance(model, AgentPassage):
agent_passages.append(model)
elif isinstance(model, SourcePassage):
source_passages.append(model)
else:
raise TypeError(f"Unexpected passage type: {type(model)}")
results = []
if agent_passages:
agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
results.extend(agent_created)
if source_passages:
source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor)
results.extend(source_created)
return [p.to_pydantic() for p in results]
@enforce_types
@trace_method
@@ -215,53 +214,66 @@ class PassageManager:
"""Insert passage(s) into archival memory"""
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
# TODO eventually migrate off of llama-index for embeddings?
# Already causing pain for OpenAI proxy endpoints like LM Studio...
if agent_state.embedding_config.embedding_endpoint_type != "openai":
embed_model = embedding_model(agent_state.embedding_config)
passages = []
if not text_chunks:
return []
try:
# breakup string into passages
for text in parse_and_chunk_text(text, embedding_chunk_size):
embeddings = await self._generate_embeddings_concurrent(text_chunks, agent_state.embedding_config)
if agent_state.embedding_config.embedding_endpoint_type != "openai":
embedding = embed_model.get_text_embedding(text)
else:
# TODO should have the settings passed in via the server call
embedding = await get_openai_embedding_async(
text,
agent_state.embedding_config.embedding_model,
agent_state.embedding_config.embedding_endpoint,
)
if isinstance(embedding, dict):
try:
embedding = embedding["data"][0]["embedding"]
except (KeyError, IndexError):
# TODO as a fallback, see if we can find any lists in the payload
raise TypeError(
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
)
passage = await self.create_passage_async(
PydanticPassage(
organization_id=actor.organization_id,
agent_id=agent_id,
text=text,
embedding=embedding,
embedding_config=agent_state.embedding_config,
),
actor=actor,
passages = [
PydanticPassage(
organization_id=actor.organization_id,
agent_id=agent_id,
text=chunk_text,
embedding=embedding,
embedding_config=agent_state.embedding_config,
)
passages.append(passage)
for chunk_text, embedding in zip(text_chunks, embeddings)
]
passages = await self.create_many_passages_async(passages=passages, actor=actor)
return passages
except Exception as e:
raise e
async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config) -> List[List[float]]:
"""Generate embeddings for all text chunks concurrently"""
if embedding_config.embedding_endpoint_type != "openai":
embed_model = embedding_model(embedding_config)
loop = asyncio.get_event_loop()
tasks = [loop.run_in_executor(None, embed_model.get_text_embedding, text) for text in text_chunks]
embeddings = await asyncio.gather(*tasks)
else:
tasks = [
get_openai_embedding_async(
text,
embedding_config.embedding_model,
embedding_config.embedding_endpoint,
)
for text in text_chunks
]
embeddings = await asyncio.gather(*tasks)
processed_embeddings = []
for embedding in embeddings:
if isinstance(embedding, dict):
try:
processed_embeddings.append(embedding["data"][0]["embedding"])
except (KeyError, IndexError):
raise TypeError(
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
)
else:
processed_embeddings.append(embedding)
return processed_embeddings
@enforce_types
@trace_method
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:

View File

@@ -0,0 +1,185 @@
import asyncio
import logging
import os
import threading
import time
import uuid
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pytest
from dotenv import load_dotenv
from faker import Faker
from letta_client import AsyncLetta
from tqdm import tqdm
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
# --- Server Management --- #
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
time.sleep(2) # Allow server startup time
return url
# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client for testing."""
client = AsyncLetta(base_url=server_url)
yield client
# --- Load Test --- #
NUM_AGENTS = 30
@pytest.mark.asyncio
async def test_insert_archival_memories_concurrent(client):
fake = Faker()
# 1) Create agents
agent_ids = []
for i in tqdm(range(NUM_AGENTS), desc="Creating agents"):
agent = await client.agents.create(
name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
include_base_tools=True,
memory_blocks=[
{"label": "human", "value": "Name: Matt"},
{"label": "persona", "value": "Friendly agent"},
],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
agent_ids.append(agent.id)
# 2) Measure start and duration of each call
timeline = []
async def measure(agent_index, aid):
t0 = time.perf_counter()
await client.agents.passages.create(agent_id=aid, text=fake.paragraph())
t1 = time.perf_counter()
timeline.append((agent_index, t0, t1 - t0))
await asyncio.gather(*(measure(idx, aid) for idx, aid in enumerate(agent_ids)))
# 3) Convert to arrays
timeline.sort(key=lambda x: x[0])
indices = np.array([t[0] for t in timeline])
starts = np.array([t[1] for t in timeline])
durs = np.array([t[2] for t in timeline])
start_offset = starts - starts.min()
print(f"Latency stats (s): min={durs.min():.3f}, mean={durs.mean():.3f}, " f"max={durs.max():.3f}, std={durs.std():.3f}")
# 4) Generate improved plots
# Helper: concurrency over time
events = np.concatenate([np.column_stack([starts, np.ones_like(starts)]), np.column_stack([starts + durs, -np.ones_like(durs)])])
events = events[events[:, 0].argsort()]
concurrency_t = np.cumsum(events[:, 1])
concurrency_x = events[:, 0] - starts.min()
# Helper: latency CDF
durs_sorted = np.sort(durs)
cdf_y = np.arange(1, len(durs_sorted) + 1) / len(durs_sorted)
# Plot all 6 subplots
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axs = axes.ravel()
# 1) Kickoff timeline
axs[0].scatter(indices, start_offset, s=15)
axs[0].set_title("Kick-off timeline")
axs[0].set_xlabel("Call index")
axs[0].set_ylabel("Start offset (s)")
# 2) Per-call latency
axs[1].plot(indices, durs, marker="o", linestyle="")
axs[1].set_title("Per-call latency")
axs[1].set_xlabel("Call index")
axs[1].set_ylabel("Duration (s)")
# 3) Latency distribution (histogram)
axs[2].hist(durs, bins="auto")
axs[2].set_title("Latency distribution")
axs[2].set_xlabel("Duration (s)")
axs[2].set_ylabel("Count")
# 4) Empirical CDF
axs[3].step(durs_sorted, cdf_y, where="post")
axs[3].set_title("Latency CDF")
axs[3].set_xlabel("Duration (s)")
axs[3].set_ylabel("Fraction ≤ x")
# 5) Concurrency over time
axs[4].step(concurrency_x, concurrency_t, where="post")
axs[4].set_title("Concurrency vs. time")
axs[4].set_xlabel("Time since first start (s)")
axs[4].set_ylabel("# in-flight")
# 6) Summary stats
axs[5].axis("off")
summary_text = (
f"n = {len(durs)}\n"
f"min = {durs.min():.3f} s\n"
f"p50 = {np.percentile(durs, 50):.3f} s\n"
f"mean = {durs.mean():.3f} s\n"
f"p95 = {np.percentile(durs, 95):.3f} s\n"
f"max = {durs.max():.3f} s\n"
f"stdev = {durs.std():.3f} s"
)
axs[5].text(0.02, 0.98, summary_text, va="top", ha="left", fontsize=11, family="monospace", transform=axs[5].transAxes)
plt.tight_layout()
plt.savefig("latency_diagnostics.png", dpi=150)
print("Saved latency_diagnostics.png")
@pytest.mark.asyncio
async def test_insert_large_archival_memory(client):
# 1) Create 30 agents
agent = await client.agents.create(
include_base_tools=True,
memory_blocks=[
{"label": "human", "value": "Name: Matt"},
{"label": "persona", "value": "Friendly agent"},
],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
file_path = Path(__file__).parent / "data" / "paper1.txt"
text = file_path.read_text()
t0 = time.perf_counter()
await client.agents.passages.create(agent_id=agent.id, text=text)
t1 = time.perf_counter()
print(f"Total time: {t1-t0}")