chore: 0.8.10 release (#2710)

This commit is contained in:
Sarah Wooders
2025-07-06 20:52:06 -07:00
committed by GitHub
44 changed files with 1935 additions and 861 deletions

View File

@@ -1,24 +1,19 @@
import logging
import os
import requests
import socket
import threading
import time
from datetime import datetime, timezone
from typing import Generator
import pytest
import requests
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
from dotenv import load_dotenv
from letta_client import Letta, AsyncLetta
from letta_client import AsyncLetta, Letta
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings
@@ -160,11 +155,13 @@ def dummy_beta_message_batch() -> BetaMessageBatch:
type="message_batch",
)
# --- Model Sweep ---
# Global flag to track server state
_server_started = False
_server_url = None
def _start_server_once() -> str:
"""Start server exactly once, return URL"""
global _server_started, _server_url
@@ -176,16 +173,18 @@ def _start_server_once() -> str:
# Check if already running
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(('localhost', 8283)) == 0:
if s.connect_ex(("localhost", 8283)) == 0:
_server_started = True
_server_url = url
return url
# Start server (your existing logic)
if not os.getenv("LETTA_SERVER_URL"):
def _run_server():
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
thread = threading.Thread(target=_run_server, daemon=True)
@@ -209,15 +208,18 @@ def _start_server_once() -> str:
_server_url = url
return url
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
"""Return URL of already-started server"""
return _start_server_once()
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
"""
@@ -274,14 +276,11 @@ def get_available_llm_configs() -> [LLMConfig]:
temp_client = Letta(base_url=server_url)
return temp_client.models.list()
# dynamically insert llm_config paramter at collection time
def pytest_generate_tests(metafunc):
"""Dynamically parametrize tests that need llm_config."""
if "llm_config" in metafunc.fixturenames:
configs = get_available_llm_configs()
if configs:
metafunc.parametrize(
"llm_config",
configs,
ids=[c.model for c in configs]
)
metafunc.parametrize("llm_config", configs, ids=[c.model for c in configs])

View File

@@ -1,19 +1,20 @@
#!/usr/bin/env python3
import json
import sys
import os
import sys
from collections import defaultdict
from datetime import datetime
def load_feature_mappings(config_file=None):
"""Load feature mappings from config file."""
if config_file is None:
# Default to feature_mappings.json in the same directory as this script
script_dir = os.path.dirname(os.path.abspath(__file__))
config_file = os.path.join(script_dir, 'feature_mappings.json')
config_file = os.path.join(script_dir, "feature_mappings.json")
try:
with open(config_file, 'r') as f:
with open(config_file, "r") as f:
return json.load(f)
except FileNotFoundError:
print(f"Error: Could not find feature mappings config file '{config_file}'")
@@ -22,14 +23,15 @@ def load_feature_mappings(config_file=None):
print(f"Error: Invalid JSON in feature mappings config file '{config_file}'")
sys.exit(1)
def get_support_status(passed_tests, feature_tests):
"""Determine support status for a feature category."""
if not feature_tests:
return "" # Unknown - no tests for this feature
# Filter out error tests when checking for support
non_error_tests = [test for test in feature_tests if not test.endswith('_error')]
error_tests = [test for test in feature_tests if test.endswith('_error')]
non_error_tests = [test for test in feature_tests if not test.endswith("_error")]
error_tests = [test for test in feature_tests if test.endswith("_error")]
# Check which non-error tests passed
passed_non_error_tests = [test for test in non_error_tests if test in passed_tests]
@@ -46,6 +48,7 @@ def get_support_status(passed_tests, feature_tests):
else:
return "⚠️" # Partial support
def categorize_tests(all_test_names, feature_mapping):
"""Categorize test names into feature buckets."""
categorized = {feature: [] for feature in feature_mapping.keys()}
@@ -58,6 +61,7 @@ def categorize_tests(all_test_names, feature_mapping):
return categorized
def calculate_support_score(feature_support, feature_order):
"""Calculate a numeric support score for ranking models.
@@ -83,86 +87,90 @@ def calculate_support_score(feature_support, feature_order):
# Unknown (❓) gets 0 points
return score
def calculate_provider_support_score(models_data, feature_order):
"""Calculate a provider-level support score based on all models' support scores."""
if not models_data:
return 0
# Calculate the average support score across all models in the provider
total_score = sum(model['support_score'] for model in models_data)
total_score = sum(model["support_score"] for model in models_data)
return total_score / len(models_data)
def get_test_function_line_numbers(test_file_path):
"""Extract line numbers for test functions from the test file."""
test_line_numbers = {}
try:
with open(test_file_path, 'r') as f:
with open(test_file_path, "r") as f:
lines = f.readlines()
for i, line in enumerate(lines, 1):
if 'def test_' in line and line.strip().startswith('def test_'):
if "def test_" in line and line.strip().startswith("def test_"):
# Extract function name
func_name = line.strip().split('def ')[1].split('(')[0]
func_name = line.strip().split("def ")[1].split("(")[0]
test_line_numbers[func_name] = i
except FileNotFoundError:
print(f"Warning: Could not find test file at {test_file_path}")
return test_line_numbers
def get_github_repo_info():
"""Get GitHub repository information from git remote."""
try:
# Try to get the GitHub repo URL from git remote
import subprocess
result = subprocess.run(['git', 'remote', 'get-url', 'origin'],
capture_output=True, text=True, cwd=os.path.dirname(__file__))
result = subprocess.run(["git", "remote", "get-url", "origin"], capture_output=True, text=True, cwd=os.path.dirname(__file__))
if result.returncode == 0:
remote_url = result.stdout.strip()
# Parse GitHub URL
if 'github.com' in remote_url:
if remote_url.startswith('https://'):
if "github.com" in remote_url:
if remote_url.startswith("https://"):
# https://github.com/user/repo.git -> user/repo
repo_path = remote_url.replace('https://github.com/', '').replace('.git', '')
elif remote_url.startswith('git@'):
repo_path = remote_url.replace("https://github.com/", "").replace(".git", "")
elif remote_url.startswith("git@"):
# git@github.com:user/repo.git -> user/repo
repo_path = remote_url.split(':')[1].replace('.git', '')
repo_path = remote_url.split(":")[1].replace(".git", "")
else:
return None
return repo_path
except:
pass
# Default fallback
return "letta-ai/letta"
def generate_test_details(model_info, feature_mapping):
"""Generate detailed test results for a model."""
details = []
# Get test function line numbers
script_dir = os.path.dirname(os.path.abspath(__file__))
test_file_path = os.path.join(script_dir, 'model_sweep.py')
test_file_path = os.path.join(script_dir, "model_sweep.py")
test_line_numbers = get_test_function_line_numbers(test_file_path)
# Use the main branch GitHub URL
base_github_url = "https://github.com/letta-ai/letta/blob/main/.github/scripts/model-sweep/model_sweep.py"
for feature, tests in model_info['categorized_tests'].items():
for feature, tests in model_info["categorized_tests"].items():
if not tests:
continue
details.append(f"### {feature}")
details.append("")
for test in sorted(tests):
if test in model_info['passed_tests']:
if test in model_info["passed_tests"]:
status = ""
elif test in model_info['failed_tests']:
elif test in model_info["failed_tests"]:
status = ""
else:
status = ""
# Create GitHub link if we have line number info
if test in test_line_numbers:
line_num = test_line_numbers[test]
@@ -171,16 +179,13 @@ def generate_test_details(model_info, feature_mapping):
else:
details.append(f"- {status} `{test}`")
details.append("")
return details
def calculate_column_widths(all_provider_data, feature_mapping):
"""Calculate the maximum width needed for each column across all providers."""
widths = {
'model': len('Model'),
'context_window': len('Context Window'),
'last_scanned': len('Last Scanned')
}
widths = {"model": len("Model"), "context_window": len("Context Window"), "last_scanned": len("Last Scanned")}
# Feature column widths
for feature in feature_mapping.keys():
@@ -191,19 +196,20 @@ def calculate_column_widths(all_provider_data, feature_mapping):
for model_info in provider_data:
# Model name width (including backticks)
model_width = len(f"`{model_info['name']}`")
widths['model'] = max(widths['model'], model_width)
widths["model"] = max(widths["model"], model_width)
# Context window width (with commas)
context_width = len(f"{model_info['context_window']:,}")
widths['context_window'] = max(widths['context_window'], context_width)
widths["context_window"] = max(widths["context_window"], context_width)
# Last scanned width
widths['last_scanned'] = max(widths['last_scanned'], len(str(model_info['last_scanned'])))
widths["last_scanned"] = max(widths["last_scanned"], len(str(model_info["last_scanned"])))
# Feature support symbols are always 2 chars, so no need to check
return widths
def process_model_sweep_report(input_file, output_file, config_file=None, debug=False):
"""Convert model sweep JSON data to MDX report."""
@@ -217,18 +223,18 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
# print()
# Read the JSON data
with open(input_file, 'r') as f:
with open(input_file, "r") as f:
data = json.load(f)
tests = data.get('tests', [])
tests = data.get("tests", [])
# if debug:
# print("DEBUG: Tests loaded:")
# print([test['outcome'] for test in tests if 'haiku' in test['nodeid']])
# Calculate summary statistics
providers = set(test['metadata']['llm_config']['provider_name'] for test in tests)
models = set(test['metadata']['llm_config']['model'] for test in tests)
providers = set(test["metadata"]["llm_config"]["provider_name"] for test in tests)
models = set(test["metadata"]["llm_config"]["model"] for test in tests)
total_tests = len(tests)
# Start building the MDX
@@ -246,13 +252,13 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
"",
f"Ran {total_tests} tests against {len(models)} models across {len(providers)} providers on {datetime.now().strftime('%B %dth, %Y')}",
"",
""
"",
]
# Group tests by provider
provider_groups = defaultdict(list)
for test in tests:
provider_name = test['metadata']['llm_config']['provider_name']
provider_name = test["metadata"]["llm_config"]["provider_name"]
provider_groups[provider_name].append(test)
# Process all providers first to collect model data
@@ -265,7 +271,7 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
# Group tests by model within this provider
model_groups = defaultdict(list)
for test in provider_tests:
model_name = test['metadata']['llm_config']['model']
model_name = test["metadata"]["llm_config"]["model"]
model_groups[model_name].append(test)
# Process all models to calculate support scores for ranking
@@ -283,15 +289,15 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
for test in model_tests:
# Extract test name from nodeid (split on :: and [)
test_name = test['nodeid'].split('::')[1].split('[')[0]
test_name = test["nodeid"].split("::")[1].split("[")[0]
all_test_names.add(test_name)
# if debug:
# print(f" Test name: {test_name}")
# print(f" Outcome: {test}")
if test['outcome'] == 'passed':
if test["outcome"] == "passed":
passed_tests.add(test_name)
elif test['outcome'] == 'failed':
elif test["outcome"] == "failed":
failed_tests.add(test_name)
# if debug:
@@ -319,16 +325,16 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
# print()
# Get context window and last scanned time
context_window = model_tests[0]['metadata']['llm_config']['context_window']
context_window = model_tests[0]["metadata"]["llm_config"]["context_window"]
# Try to get time_last_scanned from metadata, fallback to current time
try:
last_scanned = model_tests[0]['metadata'].get('time_last_scanned',
model_tests[0]['metadata'].get('timestamp',
datetime.now().isoformat()))
last_scanned = model_tests[0]["metadata"].get(
"time_last_scanned", model_tests[0]["metadata"].get("timestamp", datetime.now().isoformat())
)
# Format timestamp if it's a full ISO string
if 'T' in str(last_scanned):
last_scanned = str(last_scanned).split('T')[0] # Just the date part
if "T" in str(last_scanned):
last_scanned = str(last_scanned).split("T")[0] # Just the date part
except:
last_scanned = "Unknown"
@@ -337,19 +343,21 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
support_score = calculate_support_score(feature_support, feature_order)
# Store model data for sorting
model_data.append({
'name': model_name,
'feature_support': feature_support,
'context_window': context_window,
'last_scanned': last_scanned,
'support_score': support_score,
'failed_tests': failed_tests,
'passed_tests': passed_tests,
'categorized_tests': categorized_tests
})
model_data.append(
{
"name": model_name,
"feature_support": feature_support,
"context_window": context_window,
"last_scanned": last_scanned,
"support_score": support_score,
"failed_tests": failed_tests,
"passed_tests": passed_tests,
"categorized_tests": categorized_tests,
}
)
# Sort models by support score (descending) then by name (ascending)
model_data.sort(key=lambda x: (-x['support_score'], x['name']))
model_data.sort(key=lambda x: (-x["support_score"], x["name"]))
# Store provider data
all_provider_data[provider_name] = model_data
@@ -357,13 +365,10 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
# Calculate column widths for consistent formatting (add details column)
column_widths = calculate_column_widths(all_provider_data, feature_mapping)
column_widths['details'] = len('Details')
column_widths["details"] = len("Details")
# Sort providers by support score (descending) then by name (ascending)
sorted_providers = sorted(
provider_support_scores.keys(),
key=lambda x: (-provider_support_scores[x], x)
)
sorted_providers = sorted(provider_support_scores.keys(), key=lambda x: (-provider_support_scores[x], x))
# Generate tables for all providers first
for provider_name in sorted_providers:
@@ -377,47 +382,48 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
header_parts = [f"{'Model':<{column_widths['model']}}"]
for feature in feature_names:
header_parts.append(f"{feature:^{column_widths[feature]}}")
header_parts.extend([
f"{'Context Window':^{column_widths['context_window']}}",
f"{'Last Scanned':^{column_widths['last_scanned']}}",
f"{'Details':^{column_widths['details']}}"
])
header_parts.extend(
[
f"{'Context Window':^{column_widths['context_window']}}",
f"{'Last Scanned':^{column_widths['last_scanned']}}",
f"{'Details':^{column_widths['details']}}",
]
)
header_row = "| " + " | ".join(header_parts) + " |"
# Build separator row with left-aligned first column, centered others
separator_parts = [f"{'-' * column_widths['model']}"]
for feature in feature_names:
separator_parts.append(f":{'-' * (column_widths[feature] - 2)}:")
separator_parts.extend([
f":{'-' * (column_widths['context_window'] - 2)}:",
f":{'-' * (column_widths['last_scanned'] - 2)}:",
f":{'-' * (column_widths['details'] - 2)}:"
])
separator_parts.extend(
[
f":{'-' * (column_widths['context_window'] - 2)}:",
f":{'-' * (column_widths['last_scanned'] - 2)}:",
f":{'-' * (column_widths['details'] - 2)}:",
]
)
separator_row = "|" + "|".join(separator_parts) + "|"
# Add provider section without percentage
mdx_lines.extend([
f"## {provider_name}",
"",
header_row,
separator_row
])
mdx_lines.extend([f"## {provider_name}", "", header_row, separator_row])
# Generate table rows for sorted models with proper padding
for model_info in model_data:
# Create anchor for model details
model_anchor = model_info['name'].replace('/', '_').replace(':', '_').replace('-', '_').lower()
model_anchor = model_info["name"].replace("/", "_").replace(":", "_").replace("-", "_").lower()
details_anchor = f"{provider_name.lower().replace(' ', '_')}_{model_anchor}_details"
# Build row with left-aligned first column, centered others
row_parts = [f"`{model_info['name']}`".ljust(column_widths['model'])]
row_parts = [f"`{model_info['name']}`".ljust(column_widths["model"])]
for feature in feature_names:
row_parts.append(f"{model_info['feature_support'][feature]:^{column_widths[feature]}}")
row_parts.extend([
f"{model_info['context_window']:,}".center(column_widths['context_window']),
f"{model_info['last_scanned']}".center(column_widths['last_scanned']),
f"[View](#{details_anchor})".center(column_widths['details'])
])
row_parts.extend(
[
f"{model_info['context_window']:,}".center(column_widths["context_window"]),
f"{model_info['last_scanned']}".center(column_widths["last_scanned"]),
f"[View](#{details_anchor})".center(column_widths["details"]),
]
)
row = "| " + " | ".join(row_parts) + " |"
mdx_lines.append(row)
@@ -426,31 +432,32 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
# Add detailed test results section after all tables
mdx_lines.extend(["---", "", "# Detailed Test Results", ""])
for provider_name in sorted_providers:
model_data = all_provider_data[provider_name]
mdx_lines.extend([f"## {provider_name}", ""])
for model_info in model_data:
model_anchor = model_info['name'].replace('/', '_').replace(':', '_').replace('-', '_').lower()
model_anchor = model_info["name"].replace("/", "_").replace(":", "_").replace("-", "_").lower()
details_anchor = f"{provider_name.lower().replace(' ', '_')}_{model_anchor}_details"
mdx_lines.append(f"<a id=\"{details_anchor}\"></a>")
mdx_lines.append(f'<a id="{details_anchor}"></a>')
mdx_lines.append(f"### {model_info['name']}")
mdx_lines.append("")
# Add test details
test_details = generate_test_details(model_info, feature_mapping)
mdx_lines.extend(test_details)
# Add spacing between providers in details section
mdx_lines.extend(["", ""])
# Write the MDX file
with open(output_file, 'w') as f:
f.write('\n'.join(mdx_lines))
with open(output_file, "w") as f:
f.write("\n".join(mdx_lines))
print(f"Model sweep report saved to {output_file}")
def main():
input_file = "model_sweep_report.json"
output_file = "model_sweep_report.mdx"
@@ -483,5 +490,6 @@ def main():
print(f"Error: {e}")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -29,7 +29,6 @@ from letta_client.types import (
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
# ------------------------------
# Helper Functions and Constants
# ------------------------------
@@ -109,6 +108,7 @@ requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
def assert_greeting_with_assistant_message_response(
messages: List[Any],
streaming: bool = False,

View File

@@ -0,0 +1,33 @@
"""Add total_chunks and chunks_embedded to files
Revision ID: 47d2277e530d
Revises: 56254216524f
Create Date: 2025-07-03 14:32:08.539280
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "47d2277e530d"
down_revision: Union[str, None] = "56254216524f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("files", sa.Column("total_chunks", sa.Integer(), nullable=True))
op.add_column("files", sa.Column("chunks_embedded", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("files", "chunks_embedded")
op.drop_column("files", "total_chunks")
# ### end Alembic commands ###

View File

@@ -5,7 +5,7 @@ try:
__version__ = version("letta")
except PackageNotFoundError:
# Fallback for development installations
__version__ = "0.8.9"
__version__ = "0.8.10"
if os.environ.get("LETTA_VERSION"):

View File

@@ -82,11 +82,16 @@ class LettaAgent(BaseAgent):
step_manager: StepManager = NoopStepManager(),
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
current_run_id: str | None = None,
## summarizer settings
summarizer_mode: SummarizationMode = summarizer_settings.mode,
# for static_buffer mode
summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL,
message_buffer_limit: int = summarizer_settings.message_buffer_limit,
message_buffer_min: int = summarizer_settings.message_buffer_min,
enable_summarization: bool = summarizer_settings.enable_summarization,
max_summarization_retries: int = summarizer_settings.max_summarization_retries,
# for partial_evict mode
partial_evict_summarizer_percentage: float = summarizer_settings.partial_evict_summarizer_percentage,
):
super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)
@@ -124,11 +129,13 @@ class LettaAgent(BaseAgent):
)
self.summarizer = Summarizer(
mode=SummarizationMode(summarizer_settings.mode),
mode=summarizer_mode,
# TODO consolidate to not use this, or push it into the Summarizer() class
summarizer_agent=self.summarization_agent,
# TODO: Make this configurable
message_buffer_limit=message_buffer_limit,
message_buffer_min=message_buffer_min,
partial_evict_summarizer_percentage=partial_evict_summarizer_percentage,
)
async def _check_run_cancellation(self) -> bool:
@@ -872,25 +879,35 @@ class LettaAgent(BaseAgent):
self.logger.warning(
f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
)
new_in_context_messages, updated = self.summarizer.summarize(
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, force=True, clear=True
new_in_context_messages, updated = await self.summarizer.summarize(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
force=True,
clear=True,
)
else:
new_in_context_messages, updated = self.summarizer.summarize(
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
self.logger.info(
f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
)
new_in_context_messages, updated = await self.summarizer.summarize(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
)
await self.agent_manager.set_in_context_messages_async(
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
agent_id=self.agent_id,
message_ids=[m.id for m in new_in_context_messages],
actor=self.actor,
)
return new_in_context_messages
@trace_method
async def summarize_conversation_history(self) -> AgentState:
"""Called when the developer explicitly triggers compaction via the API"""
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
message_ids = agent_state.message_ids
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=self.actor)
new_in_context_messages, updated = self.summarizer.summarize(
new_in_context_messages, updated = await self.summarizer.summarize(
in_context_messages=in_context_messages, new_letta_messages=[], force=True
)
return await self.agent_manager.set_in_context_messages_async(

View File

@@ -295,7 +295,7 @@ class VoiceAgent(BaseAgent):
new_letta_messages = await self.message_manager.create_many_messages_async(letta_message_db_queue, actor=self.actor)
# TODO: Make this more general and configurable, less brittle
new_in_context_messages, updated = summarizer.summarize(
new_in_context_messages, updated = await summarizer.summarize(
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
)

View File

@@ -90,7 +90,7 @@ class VoiceSleeptimeAgent(LettaAgent):
current_in_context_messages, new_in_context_messages, stop_reason, usage = await super()._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
)
new_in_context_messages, updated = self.summarizer.summarize(
new_in_context_messages, updated = await self.summarizer.summarize(
in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages
)
self.agent_manager.set_in_context_messages(

View File

@@ -364,3 +364,10 @@ REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
MAX_FILES_OPEN = 5
GET_PROVIDERS_TIMEOUT_SECONDS = 10
# Pinecone related fields
PINECONE_EMBEDDING_MODEL: str = "llama-text-embed-v2"
PINECONE_TEXT_FIELD_NAME = "chunk_text"
PINECONE_METRIC = "cosine"
PINECONE_CLOUD = "aws"
PINECONE_REGION = "us-east-1"

View File

@@ -65,7 +65,7 @@ async def grep_files(
raise NotImplementedError("Tool not implemented. Please contact the Letta team.")
async def semantic_search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
async def semantic_search_files(agent_state: "AgentState", query: str, limit: int = 5) -> List["FileMetadata"]:
"""
Get list of most relevant chunks from any file using vector/embedding search.
@@ -76,6 +76,7 @@ async def semantic_search_files(agent_state: "AgentState", query: str) -> List["
Args:
query (str): The search query.
limit: Maximum number of results to return (default: 5)
Returns:
List[FileMetadata]: List of matching files.

View File

@@ -29,7 +29,6 @@ def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> d
# "Field": Field,
}
env.update(globals())
# print("About to execute source code...")
exec(source_code, env)
# print("Source code executed successfully")

View File

@@ -0,0 +1,143 @@
from typing import Any, Dict, List
from pinecone import PineconeAsyncio
from letta.constants import PINECONE_CLOUD, PINECONE_EMBEDDING_MODEL, PINECONE_METRIC, PINECONE_REGION, PINECONE_TEXT_FIELD_NAME
from letta.log import get_logger
from letta.schemas.user import User
from letta.settings import settings
logger = get_logger(__name__)
def should_use_pinecone(verbose: bool = False):
if verbose:
logger.info(
"Pinecone check: enable_pinecone=%s, api_key=%s, agent_index=%s, source_index=%s",
settings.enable_pinecone,
bool(settings.pinecone_api_key),
bool(settings.pinecone_agent_index),
bool(settings.pinecone_source_index),
)
return settings.enable_pinecone and settings.pinecone_api_key and settings.pinecone_agent_index and settings.pinecone_source_index
async def upsert_pinecone_indices():
from pinecone import IndexEmbed, PineconeAsyncio
for index_name in get_pinecone_indices():
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
if not await pc.has_index(index_name):
await pc.create_index_for_model(
name=index_name,
cloud=PINECONE_CLOUD,
region=PINECONE_REGION,
embed=IndexEmbed(model=PINECONE_EMBEDDING_MODEL, field_map={"text": PINECONE_TEXT_FIELD_NAME}, metric=PINECONE_METRIC),
)
def get_pinecone_indices() -> List[str]:
return [settings.pinecone_agent_index, settings.pinecone_source_index]
async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, chunks: List[str], actor: User):
records = []
for i, chunk in enumerate(chunks):
record = {
"_id": f"{file_id}_{i}",
PINECONE_TEXT_FIELD_NAME: chunk,
"file_id": file_id,
"source_id": source_id,
}
records.append(record)
return await upsert_records_to_pinecone_index(records, actor)
async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
from pinecone.exceptions.exceptions import NotFoundException
namespace = actor.organization_id
try:
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
await dense_index.delete(
filter={
"file_id": {"$eq": file_id},
},
namespace=namespace,
)
except NotFoundException:
logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")
async def delete_source_records_from_pinecone_index(source_id: str, actor: User):
from pinecone.exceptions.exceptions import NotFoundException
namespace = actor.organization_id
try:
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
await dense_index.delete(filter={"source_id": {"$eq": source_id}}, namespace=namespace)
except NotFoundException:
logger.warning(f"Pinecone namespace {namespace} not found for {source_id} and {actor.organization_id}")
async def upsert_records_to_pinecone_index(records: List[dict], actor: User):
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
await dense_index.upsert_records(actor.organization_id, records)
async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], actor: User) -> Dict[str, Any]:
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
namespace = actor.organization_id
try:
# Search the dense index with reranking
search_results = await dense_index.search(
namespace=namespace,
query={
"top_k": limit,
"inputs": {"text": query},
"filter": filter,
},
rerank={"model": "bge-reranker-v2-m3", "top_n": limit, "rank_fields": [PINECONE_TEXT_FIELD_NAME]},
)
return search_results
except Exception as e:
logger.warning(f"Failed to search Pinecone namespace {actor.organization_id}: {str(e)}")
raise e
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
from pinecone.exceptions.exceptions import NotFoundException
namespace = actor.organization_id
try:
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
description = await pc.describe_index(name=settings.pinecone_source_index)
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
kwargs = {"namespace": namespace, "prefix": file_id}
if limit is not None:
kwargs["limit"] = limit
if pagination_token is not None:
kwargs["pagination_token"] = pagination_token
try:
result = []
async for ids in dense_index.list(**kwargs):
result.extend(ids)
return result
except Exception as e:
logger.warning(f"Failed to list Pinecone namespace {actor.organization_id}: {str(e)}")
raise e
except NotFoundException:
logger.warning(f"Pinecone namespace {namespace} not found for {file_id} and {actor.organization_id}")

View File

@@ -216,6 +216,10 @@ class OpenAIClient(LLMClientBase):
# NOTE: the reasoners that don't support temperature require 1.0, not None
temperature=llm_config.temperature if supports_temperature_param(model) else 1.0,
)
if llm_config.frequency_penalty is not None:
data.frequency_penalty = llm_config.frequency_penalty
if tools and supports_parallel_tool_calling(model):
data.parallel_tool_calls = False

View File

@@ -60,6 +60,8 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
)
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Any error message encountered during processing.")
total_chunks: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Total number of chunks for the file.")
chunks_embedded: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="Number of chunks that have been embedded.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
@@ -112,6 +114,8 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
file_last_modified_date=self.file_last_modified_date,
processing_status=self.processing_status,
error_message=self.error_message,
total_chunks=self.total_chunks,
chunks_embedded=self.chunks_embedded,
created_at=self.created_at,
updated_at=self.updated_at,
is_deleted=self.is_deleted,

View File

@@ -1,14 +1,12 @@
WORD_LIMIT = 100
SYSTEM = f"""
Your job is to summarize a history of previous messages in a conversation between an AI persona and a human.
SYSTEM = f"""Your job is to summarize a history of previous messages in a conversation between an AI persona and a human.
The conversation you are given is a from a fixed context window and may not be complete.
Messages sent by the AI are marked with the 'assistant' role.
The AI 'assistant' can also make calls to functions, whose outputs can be seen in messages with the 'function' role.
The AI 'assistant' can also make calls to tools, whose outputs can be seen in messages with the 'tool' role.
Things the AI says in the message content are considered inner monologue and are not seen by the user.
The only AI messages seen by the user are from when the AI uses 'send_message'.
Messages the user sends are in the 'user' role.
The 'user' role is also used for important system events, such as login events and heartbeat events (heartbeats run the AI's program without user action, allowing the AI to act without prompting from the user sending them a message).
Summarize what happened in the conversation from the perspective of the AI (use the first person).
Summarize what happened in the conversation from the perspective of the AI (use the first person from the perspective of the AI).
Keep your summary less than {WORD_LIMIT} words, do NOT exceed this word limit.
Only output the summary, do NOT include anything else in your output.
"""
Only output the summary, do NOT include anything else in your output."""

View File

@@ -41,6 +41,8 @@ class FileMetadata(FileMetadataBase):
description="The current processing status of the file (e.g. pending, parsing, embedding, completed, error).",
)
error_message: Optional[str] = Field(default=None, description="Optional error message if the file failed processing.")
total_chunks: Optional[int] = Field(default=None, description="Total number of chunks for the file.")
chunks_embedded: Optional[int] = Field(default=None, description="Number of chunks that have been embedded.")
# orm metadata, optional fields
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.")
@@ -52,6 +54,10 @@ class FileMetadata(FileMetadataBase):
default=None, description="Optional full-text content of the file; only populated on demand due to its size."
)
def is_processing_terminal(self) -> bool:
"""Check if the file processing status is in a terminal state (completed or error)."""
return self.processing_status in (FileProcessingStatus.COMPLETED, FileProcessingStatus.ERROR)
class FileAgentBase(LettaBase):
"""Base class for the FileMetadata-⇄-Agent association schemas"""

View File

@@ -97,7 +97,7 @@ class LettaBase(BaseModel):
class OrmMetadataBase(LettaBase):
# metadata fields
created_by_id: Optional[str] = Field(None, description="The id of the user that made this object.")
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this object.")
created_at: Optional[datetime] = Field(None, description="The timestamp when the object was created.")
updated_at: Optional[datetime] = Field(None, description="The timestamp when the object was last updated.")
created_by_id: Optional[str] = Field(default=None, description="The id of the user that made this object.")
last_updated_by_id: Optional[str] = Field(default=None, description="The id of the user that made this object.")
created_at: Optional[datetime] = Field(default=None, description="The timestamp when the object was created.")
updated_at: Optional[datetime] = Field(default=None, description="The timestamp when the object was last updated.")

View File

@@ -72,7 +72,7 @@ class SystemMessage(LettaMessage):
content (str): The message content sent by the system
"""
message_type: Literal[MessageType.system_message] = Field(MessageType.system_message, description="The type of the message.")
message_type: Literal[MessageType.system_message] = Field(default=MessageType.system_message, description="The type of the message.")
content: str = Field(..., description="The message content sent by the system")
@@ -87,7 +87,7 @@ class UserMessage(LettaMessage):
content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user (can be a string or an array of multi-modal content parts)
"""
message_type: Literal[MessageType.user_message] = Field(MessageType.user_message, description="The type of the message.")
message_type: Literal[MessageType.user_message] = Field(default=MessageType.user_message, description="The type of the message.")
content: Union[str, List[LettaUserMessageContentUnion]] = Field(
...,
description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
@@ -109,7 +109,9 @@ class ReasoningMessage(LettaMessage):
signature (Optional[str]): The model-generated signature of the reasoning step
"""
message_type: Literal[MessageType.reasoning_message] = Field(MessageType.reasoning_message, description="The type of the message.")
message_type: Literal[MessageType.reasoning_message] = Field(
default=MessageType.reasoning_message, description="The type of the message."
)
source: Literal["reasoner_model", "non_reasoner_model"] = "non_reasoner_model"
reasoning: str
signature: Optional[str] = None
@@ -130,7 +132,7 @@ class HiddenReasoningMessage(LettaMessage):
"""
message_type: Literal[MessageType.hidden_reasoning_message] = Field(
MessageType.hidden_reasoning_message, description="The type of the message."
default=MessageType.hidden_reasoning_message, description="The type of the message."
)
state: Literal["redacted", "omitted"]
hidden_reasoning: Optional[str] = None
@@ -170,7 +172,9 @@ class ToolCallMessage(LettaMessage):
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
"""
message_type: Literal[MessageType.tool_call_message] = Field(MessageType.tool_call_message, description="The type of the message.")
message_type: Literal[MessageType.tool_call_message] = Field(
default=MessageType.tool_call_message, description="The type of the message."
)
tool_call: Union[ToolCall, ToolCallDelta]
def model_dump(self, *args, **kwargs):
@@ -222,7 +226,9 @@ class ToolReturnMessage(LettaMessage):
stderr (Optional[List(str)]): Captured stderr from the tool invocation
"""
message_type: Literal[MessageType.tool_return_message] = Field(MessageType.tool_return_message, description="The type of the message.")
message_type: Literal[MessageType.tool_return_message] = Field(
default=MessageType.tool_return_message, description="The type of the message."
)
tool_return: str
status: Literal["success", "error"]
tool_call_id: str
@@ -241,7 +247,9 @@ class AssistantMessage(LettaMessage):
content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
"""
message_type: Literal[MessageType.assistant_message] = Field(MessageType.assistant_message, description="The type of the message.")
message_type: Literal[MessageType.assistant_message] = Field(
default=MessageType.assistant_message, description="The type of the message."
)
content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
...,
description="The message content sent by the agent (can be a string or an array of content parts)",

View File

@@ -24,7 +24,7 @@ class MessageContent(BaseModel):
class TextContent(MessageContent):
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
type: Literal[MessageContentType.text] = Field(default=MessageContentType.text, description="The type of the message.")
text: str = Field(..., description="The text content of the message.")
@@ -44,27 +44,27 @@ class ImageSource(BaseModel):
class UrlImage(ImageSource):
type: Literal[ImageSourceType.url] = Field(ImageSourceType.url, description="The source type for the image.")
type: Literal[ImageSourceType.url] = Field(default=ImageSourceType.url, description="The source type for the image.")
url: str = Field(..., description="The URL of the image.")
class Base64Image(ImageSource):
type: Literal[ImageSourceType.base64] = Field(ImageSourceType.base64, description="The source type for the image.")
type: Literal[ImageSourceType.base64] = Field(default=ImageSourceType.base64, description="The source type for the image.")
media_type: str = Field(..., description="The media type for the image.")
data: str = Field(..., description="The base64 encoded image data.")
detail: Optional[str] = Field(
None,
default=None,
description="What level of detail to use when processing and understanding the image (low, high, or auto to let the model decide)",
)
class LettaImage(ImageSource):
type: Literal[ImageSourceType.letta] = Field(ImageSourceType.letta, description="The source type for the image.")
type: Literal[ImageSourceType.letta] = Field(default=ImageSourceType.letta, description="The source type for the image.")
file_id: str = Field(..., description="The unique identifier of the image file persisted in storage.")
media_type: Optional[str] = Field(None, description="The media type for the image.")
data: Optional[str] = Field(None, description="The base64 encoded image data.")
media_type: Optional[str] = Field(default=None, description="The media type for the image.")
data: Optional[str] = Field(default=None, description="The base64 encoded image data.")
detail: Optional[str] = Field(
None,
default=None,
description="What level of detail to use when processing and understanding the image (low, high, or auto to let the model decide)",
)
@@ -73,7 +73,7 @@ ImageSourceUnion = Annotated[Union[UrlImage, Base64Image, LettaImage], Field(dis
class ImageContent(MessageContent):
type: Literal[MessageContentType.image] = Field(MessageContentType.image, description="The type of the message.")
type: Literal[MessageContentType.image] = Field(default=MessageContentType.image, description="The type of the message.")
source: ImageSourceUnion = Field(..., description="The source of the image.")
@@ -164,7 +164,7 @@ def get_letta_assistant_message_content_union_str_json_schema():
class ToolCallContent(MessageContent):
type: Literal[MessageContentType.tool_call] = Field(
MessageContentType.tool_call, description="Indicates this content represents a tool call event."
default=MessageContentType.tool_call, description="Indicates this content represents a tool call event."
)
id: str = Field(..., description="A unique identifier for this specific tool call instance.")
name: str = Field(..., description="The name of the tool being called.")
@@ -175,7 +175,7 @@ class ToolCallContent(MessageContent):
class ToolReturnContent(MessageContent):
type: Literal[MessageContentType.tool_return] = Field(
MessageContentType.tool_return, description="Indicates this content represents a tool return event."
default=MessageContentType.tool_return, description="Indicates this content represents a tool return event."
)
tool_call_id: str = Field(..., description="References the ID of the ToolCallContent that initiated this tool call.")
content: str = Field(..., description="The content returned by the tool execution.")
@@ -184,23 +184,23 @@ class ToolReturnContent(MessageContent):
class ReasoningContent(MessageContent):
type: Literal[MessageContentType.reasoning] = Field(
MessageContentType.reasoning, description="Indicates this is a reasoning/intermediate step."
default=MessageContentType.reasoning, description="Indicates this is a reasoning/intermediate step."
)
is_native: bool = Field(..., description="Whether the reasoning content was generated by a reasoner model that processed this step.")
reasoning: str = Field(..., description="The intermediate reasoning or thought process content.")
signature: Optional[str] = Field(None, description="A unique identifier for this reasoning step.")
signature: Optional[str] = Field(default=None, description="A unique identifier for this reasoning step.")
class RedactedReasoningContent(MessageContent):
type: Literal[MessageContentType.redacted_reasoning] = Field(
MessageContentType.redacted_reasoning, description="Indicates this is a redacted thinking step."
default=MessageContentType.redacted_reasoning, description="Indicates this is a redacted thinking step."
)
data: str = Field(..., description="The redacted or filtered intermediate reasoning content.")
class OmittedReasoningContent(MessageContent):
type: Literal[MessageContentType.omitted_reasoning] = Field(
MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
default=MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
)
# NOTE: dropping because we don't track this kind of information for the other reasoning types
# tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")

View File

@@ -78,6 +78,10 @@ class LLMConfig(BaseModel):
0,
description="Configurable thinking budget for extended thinking. Used for enable_reasoner and also for Google Vertex models like Gemini 2.5 Flash. Minimum value is 1024 when used with enable_reasoner.",
)
frequency_penalty: Optional[float] = Field(
None, # Can also deafult to 0.0?
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. From OpenAI: Number between -2.0 and 2.0.",
)
# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())

View File

@@ -84,11 +84,11 @@ class MessageCreate(BaseModel):
description="The content of the message.",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
)
name: Optional[str] = Field(None, description="The name of the participant.")
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(None, description="The id of the LLMBatchItem that this message is associated with")
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
name: Optional[str] = Field(default=None, description="The name of the participant.")
otid: Optional[str] = Field(default=None, description="The offline threading id associated with this message")
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
data = super().model_dump(**kwargs)
@@ -101,9 +101,9 @@ class MessageCreate(BaseModel):
class MessageUpdate(BaseModel):
"""Request to update a message"""
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
role: Optional[MessageRole] = Field(default=None, description="The role of the participant.")
content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(
None,
default=None,
description="The content of the message.",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
)
@@ -112,11 +112,11 @@ class MessageUpdate(BaseModel):
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
# NOTE: we probably shouldn't allow updating the model field, otherwise this loses meaning
# model: Optional[str] = Field(None, description="The model used to make the function call.")
name: Optional[str] = Field(None, description="The name of the participant.")
name: Optional[str] = Field(default=None, description="The name of the participant.")
# NOTE: we probably shouldn't allow updating the created_at field, right?
# created_at: Optional[datetime] = Field(None, description="The time the message was created.")
tool_calls: Optional[List[OpenAIToolCall,]] = Field(None, description="The list of tool calls requested.")
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
tool_calls: Optional[List[OpenAIToolCall,]] = Field(default=None, description="The list of tool calls requested.")
tool_call_id: Optional[str] = Field(default=None, description="The id of the tool call.")
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
data = super().model_dump(**kwargs)
@@ -150,28 +150,28 @@ class Message(BaseMessage):
"""
id: str = BaseMessage.generate_id_field()
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
model: Optional[str] = Field(None, description="The model used to make the function call.")
organization_id: Optional[str] = Field(default=None, description="The unique identifier of the organization.")
agent_id: Optional[str] = Field(default=None, description="The unique identifier of the agent.")
model: Optional[str] = Field(default=None, description="The model used to make the function call.")
# Basic OpenAI-style fields
role: MessageRole = Field(..., description="The role of the participant.")
content: Optional[List[LettaMessageContentUnion]] = Field(None, description="The content of the message.")
content: Optional[List[LettaMessageContentUnion]] = Field(default=None, description="The content of the message.")
# NOTE: in OpenAI, this field is only used for roles 'user', 'assistant', and 'function' (now deprecated). 'tool' does not use it.
name: Optional[str] = Field(
None,
default=None,
description="For role user/assistant: the (optional) name of the participant. For role tool/function: the name of the function called.",
)
tool_calls: Optional[List[OpenAIToolCall]] = Field(
None, description="The list of tool calls requested. Only applicable for role assistant."
default=None, description="The list of tool calls requested. Only applicable for role assistant."
)
tool_call_id: Optional[str] = Field(None, description="The ID of the tool call. Only applicable for role tool.")
tool_call_id: Optional[str] = Field(default=None, description="The ID of the tool call. Only applicable for role tool.")
# Extras
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls")
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(None, description="The id of the LLMBatchItem that this message is associated with")
step_id: Optional[str] = Field(default=None, description="The id of the step that this message was created in.")
otid: Optional[str] = Field(default=None, description="The offline threading id associated with this message")
tool_returns: Optional[List[ToolReturn]] = Field(default=None, description="Tool execution return information for prior tool calls")
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -482,7 +482,9 @@ class Message(BaseMessage):
# TODO(caren) implicit support for only non-parts/list content types
if openai_message_dict["content"] is not None and type(openai_message_dict["content"]) is not str:
raise ValueError(f"Invalid content type: {type(openai_message_dict['content'])}")
content = [TextContent(text=openai_message_dict["content"])] if openai_message_dict["content"] else []
content: List[LettaMessageContentUnion] = (
[TextContent(text=openai_message_dict["content"])] if openai_message_dict["content"] else []
)
# TODO(caren) bad assumption here that "reasoning_content" always comes before "redacted_reasoning_content"
if "reasoning_content" in openai_message_dict and openai_message_dict["reasoning_content"]:
@@ -491,14 +493,16 @@ class Message(BaseMessage):
reasoning=openai_message_dict["reasoning_content"],
is_native=True,
signature=(
openai_message_dict["reasoning_content_signature"] if openai_message_dict["reasoning_content_signature"] else None
str(openai_message_dict["reasoning_content_signature"])
if "reasoning_content_signature" in openai_message_dict
else None
),
),
)
if "redacted_reasoning_content" in openai_message_dict and openai_message_dict["redacted_reasoning_content"]:
content.append(
RedactedReasoningContent(
data=openai_message_dict["redacted_reasoning_content"] if "redacted_reasoning_content" in openai_message_dict else None,
data=str(openai_message_dict["redacted_reasoning_content"]),
),
)
if "omitted_reasoning_content" in openai_message_dict and openai_message_dict["omitted_reasoning_content"]:
@@ -694,7 +698,7 @@ class Message(BaseMessage):
elif self.role == "assistant":
assert self.tool_calls is not None or text_content is not None
openai_message = {
"content": None if put_inner_thoughts_in_kwargs else text_content,
"content": None if (put_inner_thoughts_in_kwargs and self.tool_calls is not None) else text_content,
"role": self.role,
}
@@ -733,7 +737,7 @@ class Message(BaseMessage):
else:
warnings.warn(f"Using OpenAI with invalid 'name' field (name={self.name} role={self.role}).")
if parse_content_parts:
if parse_content_parts and self.content is not None:
for content in self.content:
if isinstance(content, ReasoningContent):
openai_message["reasoning_content"] = content.reasoning
@@ -819,7 +823,7 @@ class Message(BaseMessage):
}
content = []
# COT / reasoning / thinking
if len(self.content) > 1:
if self.content is not None and len(self.content) > 1:
for content_part in self.content:
if isinstance(content_part, ReasoningContent):
content.append(
@@ -1154,6 +1158,6 @@ class Message(BaseMessage):
class ToolReturn(BaseModel):
status: Literal["success", "error"] = Field(..., description="The status of the tool call")
stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the tool invocation")
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the tool invocation")
stdout: Optional[List[str]] = Field(default=None, description="Captured stdout (e.g. prints, logs) from the tool invocation")
stderr: Optional[List[str]] = Field(default=None, description="Captured stderr from the tool invocation")
# func_return: Optional[Any] = Field(None, description="The function return object")

View File

@@ -324,18 +324,25 @@ class OpenAIProvider(Provider):
else:
handle = self.get_handle(model_name)
configs.append(
LLMConfig(
model=model_name,
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=context_window_size,
handle=handle,
provider_name=self.name,
provider_category=self.provider_category,
)
llm_config = LLMConfig(
model=model_name,
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=context_window_size,
handle=handle,
provider_name=self.name,
provider_category=self.provider_category,
)
# gpt-4o-mini has started to regress with pretty bad emoji spam loops
# this is to counteract that
if "gpt-4o-mini" in model_name:
llm_config.frequency_penalty = 1.0
if "gpt-4.1-mini" in model_name:
llm_config.frequency_penalty = 1.0
configs.append(llm_config)
# for OpenAI, sort in reverse order
if self.base_url == "https://api.openai.com/v1":
# alphnumeric sort

View File

@@ -17,6 +17,7 @@ from letta.__init__ import __version__ as letta_version
from letta.agents.exceptions import IncompatibleAgentType
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
from letta.helpers.pinecone_utils import get_pinecone_indices, should_use_pinecone, upsert_pinecone_indices
from letta.jobs.scheduler import start_scheduler_with_leader_election
from letta.log import get_logger
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
@@ -127,6 +128,16 @@ async def lifespan(app_: FastAPI):
db_registry.initialize_async()
logger.info(f"[Worker {worker_id}] Database connections initialized")
if should_use_pinecone():
if settings.upsert_pinecone_indices:
logger.info(f"[Worker {worker_id}] Upserting pinecone indices: {get_pinecone_indices()}")
await upsert_pinecone_indices()
logger.info(f"[Worker {worker_id}] Upserted pinecone indices")
else:
logger.info(f"[Worker {worker_id}] Enabled pinecone")
else:
logger.info(f"[Worker {worker_id}] Disabled pinecone")
logger.info(f"[Worker {worker_id}] Starting scheduler with leader election")
global server
try:

View File

@@ -38,6 +38,7 @@ from letta.schemas.user import User
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.summarizer.enums import SummarizationMode
from letta.services.telemetry_manager import NoopTelemetryManager
from letta.settings import settings
from letta.utils import safe_create_task
@@ -750,6 +751,12 @@ async def send_message(
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
current_run_id=run.id,
# summarizer settings to be added here
summarizer_mode=(
SummarizationMode.STATIC_MESSAGE_BUFFER
if agent.agent_type == AgentType.voice_convo_agent
else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER
),
)
result = await agent_loop.step(
@@ -878,6 +885,12 @@ async def send_message_streaming(
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
current_run_id=run.id,
# summarizer settings to be added here
summarizer_mode=(
SummarizationMode.STATIC_MESSAGE_BUFFER
if agent.agent_type == AgentType.voice_convo_agent
else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER
),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
@@ -1014,6 +1027,12 @@ async def _process_message_background(
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
# summarizer settings to be added here
summarizer_mode=(
SummarizationMode.STATIC_MESSAGE_BUFFER
if agent.agent_type == AgentType.voice_convo_agent
else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER
),
)
result = await agent_loop.step(

View File

@@ -9,6 +9,12 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, UploadFile
from starlette import status
import letta.constants as constants
from letta.helpers.pinecone_utils import (
delete_file_records_from_pinecone_index,
delete_source_records_from_pinecone_index,
list_pinecone_index_for_files,
should_use_pinecone,
)
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -22,6 +28,7 @@ from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
from letta.services.file_processor.embedder.pinecone_embedder import PineconeEmbedder
from letta.services.file_processor.file_processor import FileProcessor
from letta.services.file_processor.file_types import (
get_allowed_media_types,
@@ -163,6 +170,10 @@ async def delete_source(
files = await server.file_manager.list_files(source_id, actor)
file_ids = [f.id for f in files]
if should_use_pinecone():
logger.info(f"Deleting source {source_id} from pinecone index")
await delete_source_records_from_pinecone_index(source_id=source_id, actor=actor)
for agent_state in agent_states:
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
@@ -326,16 +337,24 @@ async def get_file_metadata(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# Verify the source exists and user has access
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
if not source:
raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")
# Get file metadata using the file manager
file_metadata = await server.file_manager.get_file_by_id(
file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True
)
if should_use_pinecone() and not file_metadata.is_processing_terminal():
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor, limit=file_metadata.total_chunks)
logger.info(f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_id} in organization {actor.organization_id}")
if len(ids) != file_metadata.chunks_embedded or len(ids) == file_metadata.total_chunks:
if len(ids) != file_metadata.total_chunks:
file_status = file_metadata.processing_status
else:
file_status = FileProcessingStatus.COMPLETED
await server.file_manager.update_file_status(
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
)
if not file_metadata:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
@@ -364,6 +383,10 @@ async def delete_file_from_source(
await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor)
if should_use_pinecone():
logger.info(f"Deleting file {file_id} from pinecone index")
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True))
if deleted_file is None:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
@@ -402,8 +425,14 @@ async def load_file_to_source_cloud(
):
file_processor = MistralFileParser()
text_chunker = LlamaIndexChunker(chunk_size=embedding_config.embedding_chunk_size)
embedder = OpenAIEmbedder(embedding_config=embedding_config)
file_processor = FileProcessor(file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor)
using_pinecone = should_use_pinecone()
if using_pinecone:
embedder = PineconeEmbedder()
else:
embedder = OpenAIEmbedder(embedding_config=embedding_config)
file_processor = FileProcessor(
file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor, using_pinecone=using_pinecone
)
await file_processor.process(
server=server, agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata
)

View File

@@ -109,15 +109,17 @@ class FileManager:
actor: PydanticUser,
processing_status: Optional[FileProcessingStatus] = None,
error_message: Optional[str] = None,
total_chunks: Optional[int] = None,
chunks_embedded: Optional[int] = None,
) -> PydanticFileMetadata:
"""
Update processing_status and/or error_message on a FileMetadata row.
Update processing_status, error_message, total_chunks, and/or chunks_embedded on a FileMetadata row.
* 1st round-trip → UPDATE
* 2nd round-trip → SELECT fresh row (same as read_async)
"""
if processing_status is None and error_message is None:
if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None:
raise ValueError("Nothing to update")
values: dict[str, object] = {"updated_at": datetime.utcnow()}
@@ -125,6 +127,10 @@ class FileManager:
values["processing_status"] = processing_status
if error_message is not None:
values["error_message"] = error_message
if total_chunks is not None:
values["total_chunks"] = total_chunks
if chunks_embedded is not None:
values["chunks_embedded"] = chunks_embedded
async with db_registry.async_session() as session:
# Fast in-place update no ORM hydration

View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import List
from letta.log import get_logger
from letta.schemas.passage import Passage
from letta.schemas.user import User
logger = get_logger(__name__)
class BaseEmbedder(ABC):
"""Abstract base class for embedding generation"""
@abstractmethod
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
"""Generate embeddings for chunks with batching and concurrent processing"""

View File

@@ -9,12 +9,13 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
from letta.settings import model_settings
logger = get_logger(__name__)
class OpenAIEmbedder:
class OpenAIEmbedder(BaseEmbedder):
"""OpenAI-based embedding generation"""
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
@@ -24,6 +25,7 @@ class OpenAIEmbedder:
else EmbeddingConfig.default_config(model_name="letta")
)
self.embedding_config = embedding_config or self.default_embedding_config
self.max_concurrent_requests = 20
# TODO: Unify to global OpenAI client
self.client: OpenAIClient = cast(
@@ -34,7 +36,6 @@ class OpenAIEmbedder:
actor=None, # Not necessary
),
)
self.max_concurrent_requests = 20
@trace_method
async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:

View File

@@ -0,0 +1,74 @@
from typing import List
from letta.helpers.pinecone_utils import upsert_file_records_to_pinecone_index
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
try:
PINECONE_AVAILABLE = True
except ImportError:
PINECONE_AVAILABLE = False
logger = get_logger(__name__)
class PineconeEmbedder(BaseEmbedder):
"""Pinecone-based embedding generation"""
def __init__(self):
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone package is not installed. Install it with: pip install pinecone")
super().__init__()
@trace_method
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
"""Generate embeddings and upsert to Pinecone, then return Passage objects"""
if not chunks:
return []
logger.info(f"Upserting {len(chunks)} chunks to Pinecone using namespace {source_id}")
log_event(
"embedder.generation_started",
{
"total_chunks": len(chunks),
"file_id": file_id,
"source_id": source_id,
},
)
# Upsert records to Pinecone using source_id as namespace
try:
await upsert_file_records_to_pinecone_index(file_id=file_id, source_id=source_id, chunks=chunks, actor=actor)
logger.info(f"Successfully kicked off upserting {len(chunks)} records to Pinecone")
log_event(
"embedder.upsert_started",
{"records_upserted": len(chunks), "namespace": source_id, "file_id": file_id},
)
except Exception as e:
logger.error(f"Failed to upsert records to Pinecone: {str(e)}")
log_event("embedder.upsert_failed", {"error": str(e), "error_type": type(e).__name__})
raise
# Create Passage objects (without embeddings since Pinecone handles them)
passages = []
for i, text in enumerate(chunks):
passage = Passage(
text=text,
file_id=file_id,
source_id=source_id,
embedding=None, # Pinecone handles embeddings internally
embedding_config=None, # None
organization_id=actor.organization_id,
)
passages.append(passage)
logger.info(f"Successfully created {len(passages)} passages")
log_event(
"embedder.generation_completed",
{"passages_created": len(passages), "total_chunks_processed": len(chunks), "file_id": file_id, "source_id": source_id},
)
return passages

View File

@@ -11,7 +11,7 @@ from letta.server.server import SyncServer
from letta.services.file_manager import FileManager
from letta.services.file_processor.chunker.line_chunker import LineChunker
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
from letta.services.job_manager import JobManager
from letta.services.passage_manager import PassageManager
@@ -27,8 +27,9 @@ class FileProcessor:
self,
file_parser: MistralFileParser,
text_chunker: LlamaIndexChunker,
embedder: OpenAIEmbedder,
embedder: BaseEmbedder,
actor: User,
using_pinecone: bool,
max_file_size: int = 50 * 1024 * 1024, # 50MB default
):
self.file_parser = file_parser
@@ -41,6 +42,7 @@ class FileProcessor:
self.passage_manager = PassageManager()
self.job_manager = JobManager()
self.actor = actor
self.using_pinecone = using_pinecone
# TODO: Factor this function out of SyncServer
@trace_method
@@ -109,7 +111,7 @@ class FileProcessor:
logger.info("Chunking extracted text")
log_event("file_processor.chunking_started", {"filename": filename, "pages_to_process": len(ocr_response.pages)})
all_passages = []
all_chunks = []
for page in ocr_response.pages:
chunks = self.text_chunker.chunk_text(page)
@@ -118,24 +120,17 @@ class FileProcessor:
log_event("file_processor.chunking_failed", {"filename": filename, "page_index": ocr_response.pages.index(page)})
raise ValueError("No chunks created from text")
passages = await self.embedder.generate_embedded_passages(
file_id=file_metadata.id, source_id=source_id, chunks=chunks, actor=self.actor
)
log_event(
"file_processor.page_processed",
{
"filename": filename,
"page_index": ocr_response.pages.index(page),
"chunks_created": len(chunks),
"passages_generated": len(passages),
},
)
all_passages.extend(passages)
all_chunks.extend(self.text_chunker.chunk_text(page))
all_passages = await self.passage_manager.create_many_source_passages_async(
passages=all_passages, file_metadata=file_metadata, actor=self.actor
all_passages = await self.embedder.generate_embedded_passages(
file_id=file_metadata.id, source_id=source_id, chunks=all_chunks, actor=self.actor
)
log_event("file_processor.passages_created", {"filename": filename, "total_passages": len(all_passages)})
if not self.using_pinecone:
all_passages = await self.passage_manager.create_many_source_passages_async(
passages=all_passages, file_metadata=file_metadata, actor=self.actor
)
log_event("file_processor.passages_created", {"filename": filename, "total_passages": len(all_passages)})
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
log_event(
@@ -149,9 +144,14 @@ class FileProcessor:
)
# update job status
await self.file_manager.update_file_status(
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
)
if not self.using_pinecone:
await self.file_manager.update_file_status(
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
)
else:
await self.file_manager.update_file_status(
file_id=file_metadata.id, actor=self.actor, total_chunks=len(all_passages), chunks_embedded=0
)
return all_passages

View File

@@ -115,10 +115,6 @@ class JobManager:
job.completed_at = get_utc_time().replace(tzinfo=None)
if job.callback_url:
await self._dispatch_callback_async(job)
else:
logger.info(f"Job does not contain callback url: {job}")
else:
logger.info(f"Job update is not terminal {job_update}")
# Save the updated job to the database
await job.update_async(db_session=session, actor=actor)

View File

@@ -19,7 +19,6 @@ class SourceManager:
@trace_method
async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
"""Create a new source based on the PydanticSource schema."""
# Try getting the source first by id
db_source = await self.get_source_by_id(source.id, actor=actor)
if db_source:
return db_source

View File

@@ -7,3 +7,4 @@ class SummarizationMode(str, Enum):
"""
STATIC_MESSAGE_BUFFER = "static_message_buffer_mode"
PARTIAL_EVICT_MESSAGE_BUFFER = "partial_evict_message_buffer_mode"

View File

@@ -4,13 +4,19 @@ import traceback
from typing import List, Optional, Tuple, Union
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, MESSAGE_SUMMARY_REQUEST_ACK
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.prompts import gpt_summarize
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
from letta.services.summarizer.enums import SummarizationMode
from letta.system import package_summarize_message_no_counts
from letta.templates.template_helper import render_template
logger = get_logger(__name__)
@@ -29,18 +35,24 @@ class Summarizer:
summarizer_agent: Optional[Union[EphemeralSummaryAgent, "VoiceSleeptimeAgent"]] = None,
message_buffer_limit: int = 10,
message_buffer_min: int = 3,
partial_evict_summarizer_percentage: float = 0.30,
):
self.mode = mode
# Need to do validation on this
# TODO: Move this to config
self.message_buffer_limit = message_buffer_limit
self.message_buffer_min = message_buffer_min
self.summarizer_agent = summarizer_agent
# TODO: Move this to config
self.partial_evict_summarizer_percentage = partial_evict_summarizer_percentage
@trace_method
def summarize(
self, in_context_messages: List[Message], new_letta_messages: List[Message], force: bool = False, clear: bool = False
async def summarize(
self,
in_context_messages: List[Message],
new_letta_messages: List[Message],
force: bool = False,
clear: bool = False,
) -> Tuple[List[Message], bool]:
"""
Summarizes or trims in_context_messages according to the chosen mode,
@@ -58,7 +70,19 @@ class Summarizer:
(could be appended to the conversation if desired)
"""
if self.mode == SummarizationMode.STATIC_MESSAGE_BUFFER:
return self._static_buffer_summarization(in_context_messages, new_letta_messages, force=force, clear=clear)
return self._static_buffer_summarization(
in_context_messages,
new_letta_messages,
force=force,
clear=clear,
)
elif self.mode == SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER:
return await self._partial_evict_buffer_summarization(
in_context_messages,
new_letta_messages,
force=force,
clear=clear,
)
else:
# Fallback or future logic
return in_context_messages, False
@@ -75,9 +99,131 @@ class Summarizer:
task.add_done_callback(callback)
return task
def _static_buffer_summarization(
self, in_context_messages: List[Message], new_letta_messages: List[Message], force: bool = False, clear: bool = False
async def _partial_evict_buffer_summarization(
self,
in_context_messages: List[Message],
new_letta_messages: List[Message],
force: bool = False,
clear: bool = False,
) -> Tuple[List[Message], bool]:
"""Summarization as implemented in the original MemGPT loop, but using message count instead of token count.
Evict a partial amount of messages, and replace message[1] with a recursive summary.
Note that this can't be made sync, because we're waiting on the summary to inject it into the context window,
unlike the version that writes it to a block.
Unless force is True, don't summarize.
Ignore clear, we don't use it.
"""
all_in_context_messages = in_context_messages + new_letta_messages
if not force:
logger.debug("Not forcing summarization, returning in-context messages as is.")
return all_in_context_messages, False
# Very ugly code to pull LLMConfig etc from the SummarizerAgent if we're not using it for anything else
assert self.summarizer_agent is not None
# First step: determine how many messages to retain
total_message_count = len(all_in_context_messages)
assert self.partial_evict_summarizer_percentage >= 0.0 and self.partial_evict_summarizer_percentage <= 1.0
target_message_start = round((1.0 - self.partial_evict_summarizer_percentage) * total_message_count)
logger.info(f"Target message count: {total_message_count}->{(total_message_count-target_message_start)}")
# The summary message we'll insert is role 'user' (vs 'assistant', 'tool', or 'system')
# We are going to put it at index 1 (index 0 is the system message)
# That means that index 2 needs to be role 'assistant', so walk up the list starting at
# the target_message_count and find the first assistant message
for i in range(target_message_start, total_message_count):
if all_in_context_messages[i].role == MessageRole.assistant:
assistant_message_index = i
break
else:
raise ValueError(f"No assistant message found from indices {target_message_start} to {total_message_count}")
# The sequence to summarize is index 1 -> assistant_message_index
messages_to_summarize = all_in_context_messages[1:assistant_message_index]
logger.info(f"Eviction indices: {1}->{assistant_message_index}(/{total_message_count})")
# Dynamically get the LLMConfig from the summarizer agent
# Pretty cringe code here that we need the agent for this but we don't use it
agent_state = await self.summarizer_agent.agent_manager.get_agent_by_id_async(
agent_id=self.summarizer_agent.agent_id, actor=self.summarizer_agent.actor
)
# TODO if we do this via the "agent", then we can more easily allow toggling on the memory block version
summary_message_str = await simple_summary(
messages=messages_to_summarize,
llm_config=agent_state.llm_config,
actor=self.summarizer_agent.actor,
include_ack=True,
)
# TODO add counts back
# Recall message count
# num_recall_messages_current = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
# num_messages_evicted = len(messages_to_summarize)
# num_recall_messages_hidden = num_recall_messages_total - len()
# Create the summary message
summary_message_str_packed = package_summarize_message_no_counts(
summary=summary_message_str,
timezone=agent_state.timezone,
)
summary_message_obj = convert_message_creates_to_messages(
message_creates=[
MessageCreate(
role=MessageRole.user,
content=[TextContent(text=summary_message_str_packed)],
)
],
agent_id=agent_state.id,
timezone=agent_state.timezone,
# We already packed, don't pack again
wrap_user_message=False,
wrap_system_message=False,
)[0]
# Create the message in the DB
await self.summarizer_agent.message_manager.create_many_messages_async(
pydantic_msgs=[summary_message_obj],
actor=self.summarizer_agent.actor,
)
updated_in_context_messages = all_in_context_messages[assistant_message_index:]
return [all_in_context_messages[0], summary_message_obj] + updated_in_context_messages, True
def _static_buffer_summarization(
self,
in_context_messages: List[Message],
new_letta_messages: List[Message],
force: bool = False,
clear: bool = False,
) -> Tuple[List[Message], bool]:
"""
Implements static buffer summarization by maintaining a fixed-size message buffer (< N messages).
Logic:
1. Combine existing context messages with new messages
2. If total messages <= buffer limit and not forced, return unchanged
3. Calculate how many messages to retain (0 if clear=True, otherwise message_buffer_min)
4. Find the trim index to keep the most recent messages while preserving user message boundaries
5. Evict older messages (everything between system message and trim index)
6. If summarizer agent is available, trigger background summarization of evicted messages
7. Return updated context with system message + retained recent messages
Args:
in_context_messages: Existing conversation context messages
new_letta_messages: Newly added messages to append
force: Force summarization even if buffer limit not exceeded
clear: Clear all messages except system message (retain_count = 0)
Returns:
Tuple of (updated_messages, was_summarized)
- updated_messages: New context after trimming/summarization
- was_summarized: True if messages were evicted and summarization triggered
"""
all_in_context_messages = in_context_messages + new_letta_messages
if len(all_in_context_messages) <= self.message_buffer_limit and not force:
@@ -139,6 +285,91 @@ class Summarizer:
return [all_in_context_messages[0]] + updated_in_context_messages, True
def simple_formatter(messages: List[Message], include_system: bool = False) -> str:
"""Go from an OpenAI-style list of messages to a concatenated string"""
parsed_messages = [message.to_openai_dict() for message in messages if message.role != MessageRole.system or include_system]
return "\n".join(json.dumps(msg) for msg in parsed_messages)
def simple_message_wrapper(openai_msg: dict) -> Message:
"""Extremely simple way to map from role/content to Message object w/ throwaway dummy fields"""
if "role" not in openai_msg:
raise ValueError(f"Missing role in openai_msg: {openai_msg}")
if "content" not in openai_msg:
raise ValueError(f"Missing content in openai_msg: {openai_msg}")
if openai_msg["role"] == "user":
return Message(
role=MessageRole.user,
content=[TextContent(text=openai_msg["content"])],
)
elif openai_msg["role"] == "assistant":
return Message(
role=MessageRole.assistant,
content=[TextContent(text=openai_msg["content"])],
)
elif openai_msg["role"] == "system":
return Message(
role=MessageRole.system,
content=[TextContent(text=openai_msg["content"])],
)
else:
raise ValueError(f"Unknown role: {openai_msg['role']}")
async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor: User, include_ack: bool = True) -> str:
"""Generate a simple summary from a list of messages.
Intentionally kept functional due to the simplicity of the prompt.
"""
# Create an LLMClient from the config
llm_client = LLMClient.create(
provider_type=llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor=actor,
)
assert llm_client is not None
# Prepare the messages payload to send to the LLM
system_prompt = gpt_summarize.SYSTEM
summary_transcript = simple_formatter(messages)
if include_ack:
input_messages = [
{"role": "system", "content": system_prompt},
{"role": "assistant", "content": MESSAGE_SUMMARY_REQUEST_ACK},
{"role": "user", "content": summary_transcript},
]
else:
input_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": summary_transcript},
]
print("messages going to summarizer:", input_messages)
input_messages_obj = [simple_message_wrapper(msg) for msg in input_messages]
print("messages going to summarizer (objs):", input_messages_obj)
request_data = llm_client.build_request_data(input_messages_obj, llm_config, tools=[])
print("request data:", request_data)
# NOTE: we should disable the inner_thoughts_in_kwargs here, because we don't use it
# I'm leaving it commented it out for now for safety but is fine assuming the var here is a copy not a reference
# llm_config.put_inner_thoughts_in_kwargs = False
response_data = await llm_client.request_async(request_data, llm_config)
response = llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, llm_config)
if response.choices[0].message.content is None:
logger.warning("No content returned from summarizer")
# TODO raise an error error instead?
# return "[Summary failed to generate]"
raise Exception("Summary failed to generate")
else:
summary = response.choices[0].message.content.strip()
return summary
def format_transcript(messages: List[Message], include_system: bool = False) -> List[str]:
"""
Turn a list of Message objects into a human-readable transcript.

View File

@@ -2,8 +2,9 @@ import asyncio
import re
from typing import Any, Dict, List, Optional
from letta.constants import MAX_FILES_OPEN
from letta.constants import MAX_FILES_OPEN, PINECONE_TEXT_FIELD_NAME
from letta.functions.types import FileOpenRequest
from letta.helpers.pinecone_utils import search_pinecone_index, should_use_pinecone
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -463,14 +464,15 @@ class LettaFileToolExecutor(ToolExecutor):
return "\n".join(formatted_results)
@trace_method
async def semantic_search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
async def semantic_search_files(self, agent_state: AgentState, query: str, limit: int = 5) -> str:
"""
Search for text within attached files using semantic search and return passages with their source filenames.
Uses Pinecone if configured, otherwise falls back to traditional search.
Args:
agent_state: Current agent state
query: Search query for semantic matching
limit: Maximum number of results to return (default: 10)
limit: Maximum number of results to return (default: 5)
Returns:
Formatted string with search results in IDE/terminal style
@@ -485,6 +487,110 @@ class LettaFileToolExecutor(ToolExecutor):
self.logger.info(f"Semantic search started for agent {agent_state.id} with query '{query}' (limit: {limit})")
# Check if Pinecone is enabled and use it if available
if should_use_pinecone():
return await self._search_files_pinecone(agent_state, query, limit)
else:
return await self._search_files_traditional(agent_state, query, limit)
async def _search_files_pinecone(self, agent_state: AgentState, query: str, limit: int) -> str:
"""Search files using Pinecone vector database."""
# Extract unique source_ids
# TODO: Inefficient
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
source_ids = [source.id for source in attached_sources]
if not source_ids:
return f"No valid source IDs found for attached files"
# Get all attached files for this agent
file_agents = await self.files_agents_manager.list_files_for_agent(agent_id=agent_state.id, actor=self.actor)
if not file_agents:
return "No files are currently attached to search"
results = []
total_hits = 0
files_with_matches = {}
try:
filter = {"source_id": {"$in": source_ids}}
search_results = await search_pinecone_index(query, limit, filter, self.actor)
# Process search results
if "result" in search_results and "hits" in search_results["result"]:
for hit in search_results["result"]["hits"]:
if total_hits >= limit:
break
total_hits += 1
# Extract hit information
hit_id = hit.get("_id", "unknown")
score = hit.get("_score", 0.0)
fields = hit.get("fields", {})
text = fields.get(PINECONE_TEXT_FIELD_NAME, "")
file_id = fields.get("file_id", "")
# Find corresponding file name
file_name = "Unknown File"
for fa in file_agents:
if fa.file_id == file_id:
file_name = fa.file_name
break
# Group by file name
if file_name not in files_with_matches:
files_with_matches[file_name] = []
files_with_matches[file_name].append({"text": text, "score": score, "hit_id": hit_id})
except Exception as e:
self.logger.error(f"Pinecone search failed: {str(e)}")
raise e
if not files_with_matches:
return f"No semantic matches found in Pinecone for query: '{query}'"
# Format results
passage_num = 0
for file_name, matches in files_with_matches.items():
for match in matches:
passage_num += 1
# Format each passage with terminal-style header
score_display = f"(score: {match['score']:.3f})"
passage_header = f"\n=== {file_name} (passage #{passage_num}) {score_display} ==="
# Format the passage text
passage_text = match["text"].strip()
lines = passage_text.splitlines()
formatted_lines = []
for line in lines[:20]: # Limit to first 20 lines per passage
formatted_lines.append(f" {line}")
if len(lines) > 20:
formatted_lines.append(f" ... [truncated {len(lines) - 20} more lines]")
passage_content = "\n".join(formatted_lines)
results.append(f"{passage_header}\n{passage_content}")
# Mark access for files that had matches
if files_with_matches:
matched_file_names = [name for name in files_with_matches.keys() if name != "Unknown File"]
if matched_file_names:
await self.files_agents_manager.mark_access_bulk(agent_id=agent_state.id, file_names=matched_file_names, actor=self.actor)
# Create summary header
file_count = len(files_with_matches)
summary = f"Found {total_hits} Pinecone matches in {file_count} file{'s' if file_count != 1 else ''} for query: '{query}'"
# Combine all results
formatted_results = [summary, "=" * len(summary)] + results
self.logger.info(f"Pinecone search completed: {total_hits} matches across {file_count} files")
return "\n".join(formatted_results)
async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str:
"""Traditional search using existing passage manager."""
# Get semantic search results
passages = await self.agent_manager.list_source_passages_async(
actor=self.actor,

View File

@@ -14,7 +14,6 @@ from letta.otel.tracing import trace_method
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.settings import settings
from letta.utils import enforce_types
logger = get_logger(__name__)

View File

@@ -39,12 +39,17 @@ class ToolSettings(BaseSettings):
class SummarizerSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="letta_summarizer_", extra="ignore")
mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER
# mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER
mode: SummarizationMode = SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER
message_buffer_limit: int = 60
message_buffer_min: int = 15
enable_summarization: bool = True
max_summarization_retries: int = 3
# partial evict summarizer percentage
# eviction based on percentage of message count, not token count
partial_evict_summarizer_percentage: float = 0.30
# TODO(cliandy): the below settings are tied to old summarization and should be deprecated or moved
# Controls if we should evict all messages
# TODO: Can refactor this into an enum if we have a bunch of different kinds of summarizers
@@ -253,6 +258,13 @@ class Settings(BaseSettings):
llm_request_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM requests in seconds")
llm_stream_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM streaming requests in seconds")
# For embeddings
enable_pinecone: bool = False
pinecone_api_key: Optional[str] = None
pinecone_source_index: Optional[str] = "sources"
pinecone_agent_index: Optional[str] = "recall"
upsert_pinecone_indices: bool = False
@property
def letta_pg_uri(self) -> str:
if self.pg_uri:

View File

@@ -188,6 +188,22 @@ def package_summarize_message(summary, summary_message_count, hidden_message_cou
return json_dumps(packaged_message)
def package_summarize_message_no_counts(summary, timezone):
context_message = (
f"Note: prior messages have been hidden from view due to conversation memory constraints.\n"
+ f"The following is a summary of the previous messages:\n {summary}"
)
formatted_time = get_local_time(timezone=timezone)
packaged_message = {
"type": "system_alert",
"message": context_message,
"time": formatted_time,
}
return json_dumps(packaged_message)
def package_summarize_message_no_summary(hidden_message_count, message=None, timezone=None):
"""Add useful metadata to the summary message"""

1312
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.8.9"
version = "0.8.10"
packages = [
{include = "letta"},
]
@@ -98,6 +98,7 @@ redis = {version = "^6.2.0", optional = true}
structlog = "^25.4.0"
certifi = "^2025.6.15"
aioboto3 = {version = "^14.3.0", optional = true}
pinecone = {extras = ["asyncio"], version = "^7.3.0"}
aiosqlite = "^0.21.0"
@@ -121,6 +122,7 @@ black = "^24.4.2"
ipykernel = "^6.29.5"
ipdb = "^0.13.13"
pytest-mock = "^3.14.0"
pinecone = "^7.3.0"
[tool.poetry.group."dev,tests".dependencies]

View File

@@ -1,5 +1,5 @@
{
"context_window": 8192,
"context_window": 128000,
"model": "gpt-4o-mini",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",

View File

@@ -28,6 +28,24 @@ def disable_e2b_api_key() -> Generator[None, None, None]:
tool_settings.e2b_api_key = original_api_key
@pytest.fixture
def disable_pinecone() -> Generator[None, None, None]:
"""
Temporarily disables Pinecone by setting `settings.enable_pinecone` to False
and `settings.pinecone_api_key` to None for the duration of the test.
Restores the original values afterward.
"""
from letta.settings import settings
original_enable_pinecone = settings.enable_pinecone
original_pinecone_api_key = settings.pinecone_api_key
settings.enable_pinecone = False
settings.pinecone_api_key = None
yield
settings.enable_pinecone = original_enable_pinecone
settings.pinecone_api_key = original_pinecone_api_key
@pytest.fixture
def check_e2b_key_is_set():
from letta.settings import tool_settings

View File

@@ -3320,7 +3320,7 @@ async def test_update_tool_pip_requirements(server: SyncServer, print_tool, defa
# Add pip requirements to existing tool
pip_reqs = [
PipRequirement(name="pandas", version="1.5.0"),
PipRequirement(name="matplotlib"),
PipRequirement(name="sumy"),
]
tool_update = ToolUpdate(pip_requirements=pip_reqs)
@@ -3334,7 +3334,7 @@ async def test_update_tool_pip_requirements(server: SyncServer, print_tool, defa
assert len(updated_tool.pip_requirements) == 2
assert updated_tool.pip_requirements[0].name == "pandas"
assert updated_tool.pip_requirements[0].version == "1.5.0"
assert updated_tool.pip_requirements[1].name == "matplotlib"
assert updated_tool.pip_requirements[1].name == "sumy"
assert updated_tool.pip_requirements[1].version is None
@@ -5218,6 +5218,41 @@ async def test_update_file_status_error_only(server, default_user, default_sourc
assert updated.processing_status == FileProcessingStatus.PENDING # default from creation
@pytest.mark.asyncio
async def test_update_file_status_with_chunks(server, default_user, default_source):
"""Update chunk progress fields along with status."""
meta = PydanticFileMetadata(
file_name="chunks_test.txt",
file_path="/tmp/chunks_test.txt",
file_type="text/plain",
file_size=500,
source_id=default_source.id,
)
created = await server.file_manager.create_file(file_metadata=meta, actor=default_user)
# Update with chunk progress
updated = await server.file_manager.update_file_status(
file_id=created.id,
actor=default_user,
processing_status=FileProcessingStatus.EMBEDDING,
total_chunks=100,
chunks_embedded=50,
)
assert updated.processing_status == FileProcessingStatus.EMBEDDING
assert updated.total_chunks == 100
assert updated.chunks_embedded == 50
# Update only chunk progress
updated = await server.file_manager.update_file_status(
file_id=created.id,
actor=default_user,
chunks_embedded=100,
)
assert updated.chunks_embedded == 100
assert updated.total_chunks == 100 # unchanged
assert updated.processing_status == FileProcessingStatus.EMBEDDING # unchanged
@pytest.mark.asyncio
async def test_upsert_file_content_basic(server: SyncServer, default_user, default_source, async_session):
"""Test creating and updating file content with upsert_file_content()."""

View File

@@ -9,9 +9,10 @@ from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client.types import AgentState
from letta.constants import FILES_TOOLS
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
from letta.orm.enums import ToolType
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
from tests.utils import wait_for_server
# Constants
@@ -49,7 +50,7 @@ def client() -> LettaSDKClient:
yield client
def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str, max_wait: int = 30):
def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str, max_wait: int = 60):
"""Helper function to upload a file and wait for processing to complete"""
with open(file_path, "rb") as f:
file_metadata = client.sources.files.upload(source_id=source_id, file=f)
@@ -70,7 +71,7 @@ def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str,
@pytest.fixture
def agent_state(client: LettaSDKClient):
def agent_state(disable_pinecone, client: LettaSDKClient):
open_file_tool = client.tools.list(name="open_files")[0]
search_files_tool = client.tools.list(name="semantic_search_files")[0]
grep_tool = client.tools.list(name="grep_files")[0]
@@ -93,7 +94,7 @@ def agent_state(client: LettaSDKClient):
# Tests
def test_auto_attach_detach_files_tools(client: LettaSDKClient):
def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient):
"""Test automatic attachment and detachment of file tools when managing agent sources."""
# Create agent with basic configuration
agent = client.agents.create(
@@ -164,6 +165,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
],
)
def test_file_upload_creates_source_blocks_correctly(
disable_pinecone,
client: LettaSDKClient,
agent_state: AgentState,
file_path: str,
@@ -204,7 +206,7 @@ def test_file_upload_creates_source_blocks_correctly(
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
assert len(client.sources.list()) == 1
@@ -240,7 +242,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC
assert not any("test" in b.value for b in blocks)
def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
assert len(client.sources.list()) == 1
@@ -270,7 +272,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a
assert not any("test" in b.value for b in blocks)
def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -377,7 +379,7 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat
print("✓ File successfully opened with different range - content differs as expected")
def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -423,7 +425,7 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -465,7 +467,7 @@ def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: Ag
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -517,7 +519,7 @@ def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state:
assert "513:" in tool_return_message.tool_return
def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: LettaSDKClient):
def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient):
"""Test that creating an agent with source_ids parameter correctly creates source blocks."""
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -560,7 +562,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: Le
assert file_tools == set(FILES_TOOLS)
def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentState):
def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -623,7 +625,7 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta
)
def test_duplicate_file_renaming(client: LettaSDKClient):
def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient):
"""Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)"""
# Create a new source
source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small")
@@ -662,7 +664,7 @@ def test_duplicate_file_renaming(client: LettaSDKClient):
print(f" File {i+1}: original='{file.original_file_name}' → renamed='{file.file_name}'")
def test_open_files_schema_descriptions(client: LettaSDKClient):
def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
"""Test that open_files tool schema contains correct descriptions from docstring"""
# Get the open_files tool
@@ -743,3 +745,132 @@ def test_open_files_schema_descriptions(client: LettaSDKClient):
expected_length_desc = "Optional number of lines to view from offset (inclusive). If not specified, views to end of file."
assert length_prop["description"] == expected_length_desc
assert length_prop["type"] == "integer"
# --- Pinecone Tests ---
def test_pinecone_search_files_tool(client: LettaSDKClient):
"""Test that search_files tool uses Pinecone when enabled"""
from letta.helpers.pinecone_utils import should_use_pinecone
if not should_use_pinecone(verbose=True):
pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")
print("Testing Pinecone search_files tool functionality")
# Create agent with file tools
agent = client.agents.create(
name="test_pinecone_agent",
memory_blocks=[
CreateBlock(label="human", value="username: testuser"),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
)
# Create source and attach to agent
source = client.sources.create(name="test_pinecone_source", embedding="openai/text-embedding-3-small")
client.agents.sources.attach(source_id=source.id, agent_id=agent.id)
# Upload a file with searchable content
file_path = "tests/data/long_test.txt"
upload_file_and_wait(client, source.id, file_path)
# Test semantic search using Pinecone
search_response = client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Use the semantic_search_files tool to search for 'electoral history' in the files.")],
)
# Verify tool was called successfully
tool_calls = [msg for msg in search_response.messages if msg.message_type == "tool_call_message"]
assert len(tool_calls) > 0, "No tool calls found"
assert any(tc.tool_call.name == "semantic_search_files" for tc in tool_calls), "semantic_search_files not called"
# Verify tool returned results
tool_returns = [msg for msg in search_response.messages if msg.message_type == "tool_return_message"]
assert len(tool_returns) > 0, "No tool returns found"
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
# Check that results contain expected content
search_results = tool_returns[0].tool_return
print(search_results)
assert (
"electoral" in search_results.lower() or "history" in search_results.lower()
), f"Search results should contain relevant content: {search_results}"
def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
"""Test that file and source deletion removes records from Pinecone"""
import asyncio
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
if not should_use_pinecone():
pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")
print("Testing Pinecone file and source deletion lifecycle")
# Create source
source = client.sources.create(name="test_lifecycle_source", embedding="openai/text-embedding-3-small")
# Upload multiple files and wait for processing
file_paths = ["tests/data/test.txt", "tests/data/test.md"]
uploaded_files = []
for file_path in file_paths:
file_metadata = upload_file_and_wait(client, source.id, file_path)
uploaded_files.append(file_metadata)
# Get temp user for Pinecone operations
user = User(name="temp", organization_id=DEFAULT_ORG_ID)
# Test file-level deletion first
if len(uploaded_files) > 1:
file_to_delete = uploaded_files[0]
# Check records for the specific file using list function
records_before = asyncio.run(list_pinecone_index_for_files(file_to_delete.id, user))
print(f"Found {len(records_before)} records for file before deletion")
# Delete the file
client.sources.files.delete(source_id=source.id, file_id=file_to_delete.id)
# Allow time for deletion to propagate
time.sleep(2)
# Verify file records are removed
records_after = asyncio.run(list_pinecone_index_for_files(file_to_delete.id, user))
print(f"Found {len(records_after)} records for file after deletion")
assert len(records_after) == 0, f"File records should be removed from Pinecone after deletion, but found {len(records_after)}"
# Test source-level deletion - check remaining files
# Check records for remaining files
remaining_records = []
for file_metadata in uploaded_files[1:]: # Skip the already deleted file
file_records = asyncio.run(list_pinecone_index_for_files(file_metadata.id, user))
remaining_records.extend(file_records)
records_before = len(remaining_records)
print(f"Found {records_before} records for remaining files before source deletion")
# Delete the entire source
client.sources.delete(source_id=source.id)
# Allow time for deletion to propagate
time.sleep(3)
# Verify all remaining file records are removed
records_after = []
for file_metadata in uploaded_files[1:]:
file_records = asyncio.run(list_pinecone_index_for_files(file_metadata.id, user))
records_after.extend(file_records)
print(f"Found {len(records_after)} records for files after source deletion")
assert (
len(records_after) == 0
), f"All source records should be removed from Pinecone after source deletion, but found {len(records_after)}"
print("✓ Pinecone lifecycle verified - namespace is clean after source deletion")