feat: Improve insert archival memory latency (#2435)
Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
This commit is contained in:
@@ -62,34 +62,7 @@ class PassageManager:
|
||||
@trace_method
|
||||
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
||||
# Common fields for both passage types
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
|
||||
if "agent_id" in data and data["agent_id"]:
|
||||
assert not data.get("source_id"), "Passage cannot have both agent_id and source_id"
|
||||
agent_fields = {
|
||||
"agent_id": data["agent_id"],
|
||||
}
|
||||
passage = AgentPassage(**common_fields, **agent_fields)
|
||||
elif "source_id" in data and data["source_id"]:
|
||||
assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id"
|
||||
source_fields = {
|
||||
"source_id": data["source_id"],
|
||||
"file_id": data.get("file_id"),
|
||||
}
|
||||
passage = SourcePassage(**common_fields, **source_fields)
|
||||
else:
|
||||
raise ValueError("Passage must have either agent_id or source_id")
|
||||
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
|
||||
|
||||
with db_registry.session() as session:
|
||||
passage.create(session, actor=actor)
|
||||
@@ -100,6 +73,13 @@ class PassageManager:
|
||||
async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
    """Persist a single passage asynchronously.

    The target table (agent vs. source passages) is decided by
    ``_preprocess_passage_for_creation`` based on whether the passage
    carries an ``agent_id`` or a ``source_id``.

    Args:
        pydantic_passage: The passage to persist.
        actor: User performing the action.

    Returns:
        The created passage converted back to its Pydantic form.
    """
    orm_passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
    async with db_registry.async_session() as db_session:
        created = await orm_passage.create_async(db_session, actor=actor)
        return created.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
def _preprocess_passage_for_creation(self, pydantic_passage: PydanticPassage) -> "SqlAlchemyBase":
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
@@ -128,9 +108,7 @@ class PassageManager:
|
||||
else:
|
||||
raise ValueError("Passage must have either agent_id or source_id")
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
passage = await passage.create_async(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
return passage
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -142,7 +120,28 @@ class PassageManager:
|
||||
@trace_method
async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
    """Create multiple passages with at most two batched inserts (one per table).

    Args:
        passages: Passages to persist; each must carry either an ``agent_id``
            (stored as ``AgentPassage``) or a ``source_id`` (stored as
            ``SourcePassage``).
        actor: User performing the action.

    Returns:
        The created passages as Pydantic models, in the same order as the
        input list.

    Raises:
        TypeError: If preprocessing yields an unexpected ORM type.
    """
    async with db_registry.async_session() as session:
        # Bucket models by target table, remembering each passage's original
        # position so results can be returned in input order (two separate
        # batch inserts would otherwise group all agent passages ahead of
        # all source passages).
        agent_batch = []  # list of (original_index, AgentPassage)
        source_batch = []  # list of (original_index, SourcePassage)
        for idx, p in enumerate(passages):
            model = self._preprocess_passage_for_creation(p)
            if isinstance(model, AgentPassage):
                agent_batch.append((idx, model))
            elif isinstance(model, SourcePassage):
                source_batch.append((idx, model))
            else:
                raise TypeError(f"Unexpected passage type: {type(model)}")

        # NOTE(review): assumes batch_create_async returns rows in the same
        # order as `items` — confirm against its implementation.
        ordered = [None] * len(passages)
        if agent_batch:
            created = await AgentPassage.batch_create_async(
                items=[m for _, m in agent_batch], db_session=session, actor=actor
            )
            for (idx, _), row in zip(agent_batch, created):
                ordered[idx] = row
        if source_batch:
            created = await SourcePassage.batch_create_async(
                items=[m for _, m in source_batch], db_session=session, actor=actor
            )
            for (idx, _), row in zip(source_batch, created):
                ordered[idx] = row

        return [row.to_pydantic() for row in ordered]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -215,53 +214,66 @@ class PassageManager:
|
||||
"""Insert passage(s) into archival memory"""
|
||||
|
||||
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
||||
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
|
||||
|
||||
# TODO eventually migrate off of llama-index for embeddings?
|
||||
# Already causing pain for OpenAI proxy endpoints like LM Studio...
|
||||
if agent_state.embedding_config.embedding_endpoint_type != "openai":
|
||||
embed_model = embedding_model(agent_state.embedding_config)
|
||||
|
||||
passages = []
|
||||
if not text_chunks:
|
||||
return []
|
||||
|
||||
try:
|
||||
# breakup string into passages
|
||||
for text in parse_and_chunk_text(text, embedding_chunk_size):
|
||||
embeddings = await self._generate_embeddings_concurrent(text_chunks, agent_state.embedding_config)
|
||||
|
||||
if agent_state.embedding_config.embedding_endpoint_type != "openai":
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
else:
|
||||
# TODO should have the settings passed in via the server call
|
||||
embedding = await get_openai_embedding_async(
|
||||
text,
|
||||
agent_state.embedding_config.embedding_model,
|
||||
agent_state.embedding_config.embedding_endpoint,
|
||||
)
|
||||
|
||||
if isinstance(embedding, dict):
|
||||
try:
|
||||
embedding = embedding["data"][0]["embedding"]
|
||||
except (KeyError, IndexError):
|
||||
# TODO as a fallback, see if we can find any lists in the payload
|
||||
raise TypeError(
|
||||
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
||||
)
|
||||
passage = await self.create_passage_async(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
text=text,
|
||||
embedding=embedding,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
),
|
||||
actor=actor,
|
||||
passages = [
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
text=chunk_text,
|
||||
embedding=embedding,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
)
|
||||
passages.append(passage)
|
||||
for chunk_text, embedding in zip(text_chunks, embeddings)
|
||||
]
|
||||
|
||||
passages = await self.create_many_passages_async(passages=passages, actor=actor)
|
||||
|
||||
return passages
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config) -> List[List[float]]:
    """Generate embeddings for all text chunks concurrently.

    Non-OpenAI endpoints use the synchronous llama-index embedding model,
    fanned out onto the default thread-pool executor; OpenAI endpoints use
    the async embedding call directly.

    Args:
        text_chunks: Text fragments to embed.
        embedding_config: Embedding endpoint/model configuration.

    Returns:
        One embedding vector per input chunk, in input order.

    Raises:
        TypeError: If an embedding payload is a dict without the expected
            ``data[0]["embedding"]`` structure.
    """
    if embedding_config.embedding_endpoint_type != "openai":
        embed_model = embedding_model(embedding_config)
        # get_running_loop(), not the deprecated get_event_loop(): we are
        # always executing inside a coroutine here.
        loop = asyncio.get_running_loop()
        tasks = [loop.run_in_executor(None, embed_model.get_text_embedding, text) for text in text_chunks]
    else:
        tasks = [
            get_openai_embedding_async(
                text,
                embedding_config.embedding_model,
                embedding_config.embedding_endpoint,
            )
            for text in text_chunks
        ]
    embeddings = await asyncio.gather(*tasks)

    processed_embeddings = []
    for embedding in embeddings:
        if isinstance(embedding, dict):
            # Some endpoints return the raw OpenAI-style response payload;
            # unwrap the actual vector from it.
            try:
                processed_embeddings.append(embedding["data"][0]["embedding"])
            except (KeyError, IndexError) as err:
                raise TypeError(
                    f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
                ) from err
        else:
            processed_embeddings.append(embedding)

    return processed_embeddings
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:
|
||||
|
||||
185
performance_tests/test_insert_archival_memory.py
Normal file
185
performance_tests/test_insert_archival_memory.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from faker import Faker
|
||||
from letta_client import AsyncLetta
|
||||
from tqdm import tqdm
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
# --- Server Management --- #
|
||||
|
||||
|
||||
def _run_server():
    """Run the Letta REST server in the current thread (blocking)."""
    load_dotenv()

    # NOTE(review): import kept after load_dotenv(), presumably so the app
    # reads .env values at import time — confirm.
    from letta.server.rest_api.app import start_server

    start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def server_url():
    """Ensure a Letta server is reachable and return its base URL.

    Uses LETTA_SERVER_URL when set; otherwise boots an in-process server
    on a daemon thread and falls back to the default local URL.
    """
    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        # No external server configured: start one in the background.
        server_thread = threading.Thread(target=_run_server, daemon=True)
        server_thread.start()
        time.sleep(2)  # crude startup wait

    return url
|
||||
|
||||
|
||||
# --- Client Setup --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def client(server_url):
    """Yield an async REST client pointed at the test server."""
    yield AsyncLetta(base_url=server_url)
|
||||
|
||||
|
||||
# --- Load Test --- #
|
||||
|
||||
NUM_AGENTS = 30  # Number of agents created == number of concurrent insert calls in the load test.
|
||||
|
||||
|
||||
def _plot_latency_diagnostics(indices, start_offset, durs):
    """Render the 6-panel latency diagnostics figure and save it to disk.

    Args:
        indices: Call indices, ordered by call index.
        start_offset: Per-call start times relative to the earliest start (s).
        durs: Per-call durations (s).
    """
    # Concurrency over time: a +1 event at each start, -1 at each end;
    # the running sum is the number of in-flight calls.
    events = np.concatenate(
        [
            np.column_stack([start_offset, np.ones_like(start_offset)]),
            np.column_stack([start_offset + durs, -np.ones_like(durs)]),
        ]
    )
    events = events[events[:, 0].argsort()]
    concurrency_t = np.cumsum(events[:, 1])
    concurrency_x = events[:, 0]

    # Empirical latency CDF.
    durs_sorted = np.sort(durs)
    cdf_y = np.arange(1, len(durs_sorted) + 1) / len(durs_sorted)

    _, axes = plt.subplots(2, 3, figsize=(15, 8))
    axs = axes.ravel()

    # 1) Kickoff timeline.
    axs[0].scatter(indices, start_offset, s=15)
    axs[0].set_title("Kick-off timeline")
    axs[0].set_xlabel("Call index")
    axs[0].set_ylabel("Start offset (s)")

    # 2) Per-call latency.
    axs[1].plot(indices, durs, marker="o", linestyle="")
    axs[1].set_title("Per-call latency")
    axs[1].set_xlabel("Call index")
    axs[1].set_ylabel("Duration (s)")

    # 3) Latency distribution (histogram).
    axs[2].hist(durs, bins="auto")
    axs[2].set_title("Latency distribution")
    axs[2].set_xlabel("Duration (s)")
    axs[2].set_ylabel("Count")

    # 4) Empirical CDF.
    axs[3].step(durs_sorted, cdf_y, where="post")
    axs[3].set_title("Latency CDF")
    axs[3].set_xlabel("Duration (s)")
    axs[3].set_ylabel("Fraction ≤ x")

    # 5) Concurrency over time.
    axs[4].step(concurrency_x, concurrency_t, where="post")
    axs[4].set_title("Concurrency vs. time")
    axs[4].set_xlabel("Time since first start (s)")
    axs[4].set_ylabel("# in-flight")

    # 6) Summary statistics panel (text only).
    axs[5].axis("off")
    summary_text = (
        f"n = {len(durs)}\n"
        f"min = {durs.min():.3f} s\n"
        f"p50 = {np.percentile(durs, 50):.3f} s\n"
        f"mean = {durs.mean():.3f} s\n"
        f"p95 = {np.percentile(durs, 95):.3f} s\n"
        f"max = {durs.max():.3f} s\n"
        f"stdev = {durs.std():.3f} s"
    )
    axs[5].text(0.02, 0.98, summary_text, va="top", ha="left", fontsize=11, family="monospace", transform=axs[5].transAxes)

    plt.tight_layout()
    plt.savefig("latency_diagnostics.png", dpi=150)
    print("Saved latency_diagnostics.png")


@pytest.mark.asyncio
async def test_insert_archival_memories_concurrent(client):
    """Fire NUM_AGENTS concurrent archival-memory inserts; report and plot latency."""
    fake = Faker()

    # 1) Create the agents.
    agent_ids = []
    for i in tqdm(range(NUM_AGENTS), desc="Creating agents"):
        agent = await client.agents.create(
            name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
            include_base_tools=True,
            memory_blocks=[
                {"label": "human", "value": "Name: Matt"},
                {"label": "persona", "value": "Friendly agent"},
            ],
            llm_config=LLMConfig.default_config("gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
        )
        agent_ids.append(agent.id)

    # 2) Measure the start time and duration of each insert call.
    timeline = []

    async def measure(agent_index, aid):
        t0 = time.perf_counter()
        await client.agents.passages.create(agent_id=aid, text=fake.paragraph())
        t1 = time.perf_counter()
        timeline.append((agent_index, t0, t1 - t0))

    await asyncio.gather(*(measure(idx, aid) for idx, aid in enumerate(agent_ids)))

    # 3) Convert to arrays ordered by call index.
    timeline.sort(key=lambda x: x[0])
    indices = np.array([t[0] for t in timeline])
    starts = np.array([t[1] for t in timeline])
    durs = np.array([t[2] for t in timeline])
    start_offset = starts - starts.min()

    print(f"Latency stats (s): min={durs.min():.3f}, mean={durs.mean():.3f}, " f"max={durs.max():.3f}, std={durs.std():.3f}")

    # 4) Generate the diagnostics plots.
    _plot_latency_diagnostics(indices, start_offset, durs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_insert_large_archival_memory(client):
    """Time a single archival-memory insert of an entire paper's text."""
    # One agent is enough for this single large insert.
    agent = await client.agents.create(
        include_base_tools=True,
        memory_blocks=[
            {"label": "human", "value": "Name: Matt"},
            {"label": "persona", "value": "Friendly agent"},
        ],
        llm_config=LLMConfig.default_config("gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
    )

    # Load the large text fixture.
    paper_path = Path(__file__).parent / "data" / "paper1.txt"
    text = paper_path.read_text()

    # Time the full insert round-trip.
    t0 = time.perf_counter()
    await client.agents.passages.create(agent_id=agent.id, text=text)
    t1 = time.perf_counter()

    print(f"Total time: {t1-t0}")
|
||||
Reference in New Issue
Block a user