From fd7c8193fef770aa3264e047df0e3bd5b4912b20 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 5 Nov 2025 17:55:06 -0800 Subject: [PATCH] feat: remove chunking for archival memory [LET-6080] (#5997) * feat: remove chunking for archival memory * add error and tests --- letta/embeddings.py | 53 ------------------------------- letta/server/server.py | 11 +++++++ letta/services/passage_manager.py | 4 +-- letta/settings.py | 3 ++ tests/test_sdk_client.py | 20 ++++++++++++ 5 files changed, 36 insertions(+), 55 deletions(-) delete mode 100644 letta/embeddings.py diff --git a/letta/embeddings.py b/letta/embeddings.py deleted file mode 100644 index a07e16c3..00000000 --- a/letta/embeddings.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import List - -import tiktoken - -from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP -from letta.utils import printd - - -def parse_and_chunk_text(text: str, chunk_size: int) -> List[str]: - from llama_index.core import Document as LlamaIndexDocument - from llama_index.core.node_parser import SentenceSplitter - - parser = SentenceSplitter(chunk_size=chunk_size) - llama_index_docs = [LlamaIndexDocument(text=text)] - nodes = parser.get_nodes_from_documents(llama_index_docs) - return [n.text for n in nodes] - - -def truncate_text(text: str, max_length: int, encoding) -> str: - # truncate the text based on max_length and encoding - encoded_text = encoding.encode(text)[:max_length] - return encoding.decode(encoded_text) - - -def check_and_split_text(text: str, embedding_model: str) -> List[str]: - """Split text into chunks of max_length tokens or less""" - - if embedding_model in EMBEDDING_TO_TOKENIZER_MAP: - encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_MAP[embedding_model]) - else: - print(f"Warning: couldn't find tokenizer for model {embedding_model}, using default tokenizer {EMBEDDING_TO_TOKENIZER_DEFAULT}") - encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_DEFAULT) - - num_tokens = len(encoding.encode(text)) - - # determine max length - if hasattr(encoding, "max_length"): - # TODO(fix) this is broken - max_length = encoding.max_length - else: - # TODO: figure out the real number - printd(f"Warning: couldn't find max_length for tokenizer {embedding_model}, using default max_length 8191") - max_length = 8191 - - # truncate text if too long - if num_tokens > max_length: - print(f"Warning: text is too long ({num_tokens} tokens), truncating to {max_length} tokens.") - # First, apply any necessary formatting - formatted_text = format_text(text, embedding_model) - # Then truncate - text = truncate_text(formatted_text, max_length, encoding) - - return [text] diff --git a/letta/server/server.py b/letta/server/server.py index 1b1d020d..eef0ff20 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -634,6 +634,17 @@ class SyncServer(object): async def insert_archival_memory_async( self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime] ) -> List[Passage]: + from letta.settings import settings + from letta.utils import count_tokens + + # Check token count against limit + token_count = count_tokens(memory_contents) + if token_count > settings.archival_memory_token_limit: + raise LettaInvalidArgumentError( + message=f"Archival memory content exceeds token limit of {settings.archival_memory_token_limit} tokens (found {token_count} tokens)", + argument_name="memory_contents", + ) + # Get the agent object (loaded in memory) agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 7d8e3ee7..1003ba8a 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -8,7 +8,6 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from letta.constants import MAX_EMBEDDING_DIM -from letta.embeddings import parse_and_chunk_text from letta.helpers.decorators import async_redis_cache from letta.llm_api.llm_client import LLMClient from letta.log import get_logger @@ -427,7 +426,8 @@ class PassageManager: # Get or create the default archive for the agent archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor) - text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size)) + # TODO: check to make sure token count is okay for embedding model + text_chunks = [text] if not text_chunks: return [] diff --git a/letta/settings.py b/letta/settings.py index 0b02d0dd..9147f53f 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -339,6 +339,9 @@ class Settings(BaseSettings): # enabling letta_agent_v1 architecture use_letta_v1_agent: bool = False + # Archival memory token limit + archival_memory_token_limit: int = 8192 + @property def letta_pg_uri(self) -> str: if self.pg_uri: diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 97e7f5d8..5679cf3d 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -636,6 +636,26 @@ def test_insert_archival_memory(client: LettaSDKClient, agent: AgentState): assert "This is a test passage" not in passage_texts, f"Memory was not deleted: {passage_texts}" +def test_insert_archival_memory_exceeds_token_limit(client: LettaSDKClient, agent: AgentState): + """Test that inserting archival memory exceeding token limit raises an error.""" + from letta.settings import settings + + # Create a text that exceeds the token limit (default 8192) + # Each word is roughly 1-2 tokens, so we'll create a large enough text + long_text = " ".join(["word"] * (settings.archival_memory_token_limit + 1000)) + + # Attempt to insert and expect an error + with pytest.raises(ApiError) as exc_info: + client.agents.passages.create( + agent_id=agent.id, + text=long_text, + ) + + # Verify the error is an INVALID_ARGUMENT error + assert exc_info.value.status_code == 400, f"Expected 400 status code, got {exc_info.value.status_code}" + assert "token limit" in str(exc_info.value).lower(), f"Error message should mention token limit: {exc_info.value}" + + def test_search_archival_memory(client: LettaSDKClient, agent: AgentState): from datetime import datetime, timezone