feat: remove chunking for archival memory [LET-6080] (#5997)
* feat: remove chunking for archival memory * add error and tests
This commit is contained in:
committed by
Caren Thomas
parent
61b1a7f600
commit
fd7c8193fe
@@ -1,53 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import tiktoken
|
||||
|
||||
from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP
|
||||
from letta.utils import printd
|
||||
|
||||
|
||||
def parse_and_chunk_text(text: str, chunk_size: int) -> List[str]:
|
||||
from llama_index.core import Document as LlamaIndexDocument
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
parser = SentenceSplitter(chunk_size=chunk_size)
|
||||
llama_index_docs = [LlamaIndexDocument(text=text)]
|
||||
nodes = parser.get_nodes_from_documents(llama_index_docs)
|
||||
return [n.text for n in nodes]
|
||||
|
||||
|
||||
def truncate_text(text: str, max_length: int, encoding) -> str:
|
||||
# truncate the text based on max_length and encoding
|
||||
encoded_text = encoding.encode(text)[:max_length]
|
||||
return encoding.decode(encoded_text)
|
||||
|
||||
|
||||
def check_and_split_text(text: str, embedding_model: str) -> List[str]:
|
||||
"""Split text into chunks of max_length tokens or less"""
|
||||
|
||||
if embedding_model in EMBEDDING_TO_TOKENIZER_MAP:
|
||||
encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_MAP[embedding_model])
|
||||
else:
|
||||
print(f"Warning: couldn't find tokenizer for model {embedding_model}, using default tokenizer {EMBEDDING_TO_TOKENIZER_DEFAULT}")
|
||||
encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_DEFAULT)
|
||||
|
||||
num_tokens = len(encoding.encode(text))
|
||||
|
||||
# determine max length
|
||||
if hasattr(encoding, "max_length"):
|
||||
# TODO(fix) this is broken
|
||||
max_length = encoding.max_length
|
||||
else:
|
||||
# TODO: figure out the real number
|
||||
printd(f"Warning: couldn't find max_length for tokenizer {embedding_model}, using default max_length 8191")
|
||||
max_length = 8191
|
||||
|
||||
# truncate text if too long
|
||||
if num_tokens > max_length:
|
||||
print(f"Warning: text is too long ({num_tokens} tokens), truncating to {max_length} tokens.")
|
||||
# First, apply any necessary formatting
|
||||
formatted_text = format_text(text, embedding_model)
|
||||
# Then truncate
|
||||
text = truncate_text(formatted_text, max_length, encoding)
|
||||
|
||||
return [text]
|
||||
@@ -634,6 +634,17 @@ class SyncServer(object):
|
||||
async def insert_archival_memory_async(
|
||||
self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
|
||||
) -> List[Passage]:
|
||||
from letta.settings import settings
|
||||
from letta.utils import count_tokens
|
||||
|
||||
# Check token count against limit
|
||||
token_count = count_tokens(memory_contents)
|
||||
if token_count > settings.archival_memory_token_limit:
|
||||
raise LettaInvalidArgumentError(
|
||||
message=f"Archival memory content exceeds token limit of {settings.archival_memory_token_limit} tokens (found {token_count} tokens)",
|
||||
argument_name="memory_contents",
|
||||
)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import parse_and_chunk_text
|
||||
from letta.helpers.decorators import async_redis_cache
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
@@ -427,7 +426,8 @@ class PassageManager:
|
||||
# Get or create the default archive for the agent
|
||||
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor)
|
||||
|
||||
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
|
||||
# TODO: check to make sure token count is okay for embedding model
|
||||
text_chunks = [text]
|
||||
|
||||
if not text_chunks:
|
||||
return []
|
||||
|
||||
@@ -339,6 +339,9 @@ class Settings(BaseSettings):
|
||||
# enabling letta_agent_v1 architecture
|
||||
use_letta_v1_agent: bool = False
|
||||
|
||||
# Archival memory token limit
|
||||
archival_memory_token_limit: int = 8192
|
||||
|
||||
@property
|
||||
def letta_pg_uri(self) -> str:
|
||||
if self.pg_uri:
|
||||
|
||||
@@ -636,6 +636,26 @@ def test_insert_archival_memory(client: LettaSDKClient, agent: AgentState):
|
||||
assert "This is a test passage" not in passage_texts, f"Memory was not deleted: {passage_texts}"
|
||||
|
||||
|
||||
def test_insert_archival_memory_exceeds_token_limit(client: LettaSDKClient, agent: AgentState):
|
||||
"""Test that inserting archival memory exceeding token limit raises an error."""
|
||||
from letta.settings import settings
|
||||
|
||||
# Create a text that exceeds the token limit (default 8192)
|
||||
# Each word is roughly 1-2 tokens, so we'll create a large enough text
|
||||
long_text = " ".join(["word"] * (settings.archival_memory_token_limit + 1000))
|
||||
|
||||
# Attempt to insert and expect an error
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
client.agents.passages.create(
|
||||
agent_id=agent.id,
|
||||
text=long_text,
|
||||
)
|
||||
|
||||
# Verify the error is an INVALID_ARGUMENT error
|
||||
assert exc_info.value.status_code == 400, f"Expected 400 status code, got {exc_info.value.status_code}"
|
||||
assert "token limit" in str(exc_info.value).lower(), f"Error message should mention token limit: {exc_info.value}"
|
||||
|
||||
|
||||
def test_search_archival_memory(client: LettaSDKClient, agent: AgentState):
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
Reference in New Issue
Block a user