feat: remove chunking for archival memory [LET-6080] (#5997)

* feat: remove chunking for archival memory

* add error and tests
Author: Sarah Wooders
Date: 2025-11-05 17:55:06 -08:00
Committed by: Caren Thomas
Parent commit: 61b1a7f600
Commit: fd7c8193fe
5 changed files with 36 additions and 55 deletions

View File

@@ -1,53 +0,0 @@
from typing import List
import tiktoken
from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP
from letta.utils import printd
def parse_and_chunk_text(text: str, chunk_size: int) -> List[str]:
    """Split ``text`` into sentence-aware chunks of at most ``chunk_size`` tokens.

    Args:
        text: The raw text to split.
        chunk_size: Maximum chunk size passed to the sentence splitter.

    Returns:
        A list of chunk strings, one per parsed node.
    """
    # Imported lazily so llama_index is only required when chunking is actually used.
    from llama_index.core import Document as LlamaIndexDocument
    from llama_index.core.node_parser import SentenceSplitter

    splitter = SentenceSplitter(chunk_size=chunk_size)
    document = LlamaIndexDocument(text=text)
    return [node.text for node in splitter.get_nodes_from_documents([document])]
def truncate_text(text: str, max_length: int, encoding) -> str:
    """Truncate ``text`` to at most ``max_length`` tokens under ``encoding``.

    Args:
        text: The text to truncate.
        max_length: Maximum number of tokens to keep.
        encoding: A tokenizer exposing ``encode``/``decode`` (e.g. a tiktoken
            ``Encoding`` — duck-typed, so any compatible object works).

    Returns:
        The decoded text of the first ``max_length`` tokens.
    """
    token_ids = encoding.encode(text)
    kept_tokens = token_ids[:max_length]
    return encoding.decode(kept_tokens)
def check_and_split_text(text: str, embedding_model: str) -> List[str]:
    """Truncate ``text`` to the tokenizer's max length for ``embedding_model``.

    Despite the name, this never splits into multiple chunks: it returns a
    single-element list containing the (possibly truncated) text.

    Args:
        text: The text to check.
        embedding_model: Model name used to look up the matching tokenizer.

    Returns:
        A one-element list with the text, truncated to ``max_length`` tokens
        when it was too long.
    """
    if embedding_model in EMBEDDING_TO_TOKENIZER_MAP:
        encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_MAP[embedding_model])
    else:
        print(f"Warning: couldn't find tokenizer for model {embedding_model}, using default tokenizer {EMBEDDING_TO_TOKENIZER_DEFAULT}")
        encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_DEFAULT)

    num_tokens = len(encoding.encode(text))

    # determine max length
    if hasattr(encoding, "max_length"):
        # NOTE(review): tiktoken Encoding objects do not expose `max_length`,
        # so this branch appears to be dead in practice — confirm before relying on it.
        max_length = encoding.max_length
    else:
        # TODO: figure out the real per-model limit instead of a hard-coded default
        printd(f"Warning: couldn't find max_length for tokenizer {embedding_model}, using default max_length 8191")
        max_length = 8191

    # truncate text if too long
    if num_tokens > max_length:
        print(f"Warning: text is too long ({num_tokens} tokens), truncating to {max_length} tokens.")
        # Bug fix: the previous version called an undefined `format_text(text,
        # embedding_model)` helper here, which raised NameError whenever
        # truncation was actually needed. Truncate the text directly instead.
        text = truncate_text(text, max_length, encoding)

    return [text]

View File

@@ -634,6 +634,17 @@ class SyncServer(object):
async def insert_archival_memory_async(
self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
) -> List[Passage]:
from letta.settings import settings
from letta.utils import count_tokens
# Check token count against limit
token_count = count_tokens(memory_contents)
if token_count > settings.archival_memory_token_limit:
raise LettaInvalidArgumentError(
message=f"Archival memory content exceeds token limit of {settings.archival_memory_token_limit} tokens (found {token_count} tokens)",
argument_name="memory_contents",
)
# Get the agent object (loaded in memory)
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)

View File

@@ -8,7 +8,6 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from letta.constants import MAX_EMBEDDING_DIM
from letta.embeddings import parse_and_chunk_text
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
@@ -427,7 +426,8 @@ class PassageManager:
# Get or create the default archive for the agent
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor)
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
# TODO: check to make sure token count is okay for embedding model
text_chunks = [text]
if not text_chunks:
return []

View File

@@ -339,6 +339,9 @@ class Settings(BaseSettings):
# enabling letta_agent_v1 architecture
use_letta_v1_agent: bool = False
# Archival memory token limit
archival_memory_token_limit: int = 8192
@property
def letta_pg_uri(self) -> str:
if self.pg_uri:

View File

@@ -636,6 +636,26 @@ def test_insert_archival_memory(client: LettaSDKClient, agent: AgentState):
assert "This is a test passage" not in passage_texts, f"Memory was not deleted: {passage_texts}"
def test_insert_archival_memory_exceeds_token_limit(client: LettaSDKClient, agent: AgentState):
    """Test that inserting archival memory exceeding token limit raises an error."""
    from letta.settings import settings

    # Build a payload guaranteed to be over the configured limit (default 8192):
    # each repeated "word" costs at least ~1 token, so limit + 1000 words is enough.
    long_text = " ".join(["word"] * (settings.archival_memory_token_limit + 1000))

    # The server should reject the insert outright rather than truncating/chunking.
    with pytest.raises(ApiError) as exc_info:
        client.agents.passages.create(agent_id=agent.id, text=long_text)

    # Verify the error surfaces as an INVALID_ARGUMENT (HTTP 400) with a
    # message that points at the token limit.
    assert exc_info.value.status_code == 400, f"Expected 400 status code, got {exc_info.value.status_code}"
    assert "token limit" in str(exc_info.value).lower(), f"Error message should mention token limit: {exc_info.value}"
def test_search_archival_memory(client: LettaSDKClient, agent: AgentState):
from datetime import datetime, timezone