feat: Implement grep tool (#2694)
This commit is contained in:
@@ -32,12 +32,13 @@ async def close_file(agent_state: "AgentState", file_name: str) -> str:
|
||||
raise NotImplementedError("Tool not implemented. Please contact the Letta team.")
|
||||
|
||||
|
||||
async def grep(agent_state: "AgentState", pattern: str) -> str:
|
||||
async def grep(agent_state: "AgentState", pattern: str, include: Optional[str] = None) -> str:
|
||||
"""
|
||||
Grep tool to search files across data sources with keywords.
|
||||
Grep tool to search files across data sources with a keyword or regex pattern.
|
||||
|
||||
Args:
|
||||
pattern (str): Keyword or regex pattern to search.
|
||||
pattern (str): Keyword or regex pattern to search within file contents.
|
||||
include (Optional[str]): Optional keyword or regex pattern to filter filenames to include in the search.
|
||||
|
||||
Returns:
|
||||
str: Matching lines or summary output.
|
||||
|
||||
@@ -11,10 +11,9 @@ class LineChunker:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
# TODO: Make this more general beyond Mistral
|
||||
def chunk_text(self, text: str, start: Optional[int] = None, end: Optional[int] = None) -> List[str]:
|
||||
def chunk_text(self, text: str, start: Optional[int] = None, end: Optional[int] = None, add_metadata: bool = True) -> List[str]:
|
||||
"""Split lines"""
|
||||
content_lines = [line.strip() for line in text.split("\n") if line.strip()]
|
||||
content_lines = [line.strip() for line in text.splitlines() if line.strip()]
|
||||
total_lines = len(content_lines)
|
||||
|
||||
if start and end:
|
||||
@@ -23,12 +22,13 @@ class LineChunker:
|
||||
else:
|
||||
line_offset = 0
|
||||
|
||||
content_lines = [f"Line {i + line_offset}: {line}" for i, line in enumerate(content_lines)]
|
||||
content_lines = [f"{i + line_offset}: {line}" for i, line in enumerate(content_lines)]
|
||||
|
||||
# Add metadata about total lines
|
||||
if start and end:
|
||||
content_lines.insert(0, f"[Viewing lines {start} to {end} (out of {total_lines} lines)]")
|
||||
else:
|
||||
content_lines.insert(0, f"[Viewing file start (out of {total_lines} lines)]")
|
||||
if add_metadata:
|
||||
if start and end:
|
||||
content_lines.insert(0, f"[Viewing lines {start} to {end-1} (out of {total_lines} lines)]")
|
||||
else:
|
||||
content_lines.insert(0, f"[Viewing file start (out of {total_lines} lines)]")
|
||||
|
||||
return content_lines
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
@@ -19,6 +22,15 @@ from letta.utils import get_friendly_error_msg
|
||||
class LettaFileToolExecutor(ToolExecutor):
|
||||
"""Executor for Letta file tools with direct implementation of functions."""
|
||||
|
||||
# Production safety constants
|
||||
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50MB limit per file
|
||||
MAX_TOTAL_CONTENT_SIZE = 200 * 1024 * 1024 # 200MB total across all files
|
||||
MAX_REGEX_COMPLEXITY = 1000 # Prevent catastrophic backtracking
|
||||
MAX_MATCHES_PER_FILE = 20 # Limit matches per file
|
||||
MAX_TOTAL_MATCHES = 50 # Global match limit
|
||||
GREP_TIMEOUT_SECONDS = 30 # Max time for grep operation
|
||||
MAX_CONTEXT_LINES = 1 # Lines of context around matches
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_manager: MessageManager,
|
||||
@@ -38,6 +50,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
# TODO: This should be passed in to for testing purposes
|
||||
self.files_agents_manager = FileAgentManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
@@ -120,19 +133,279 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
)
|
||||
return "Success"
|
||||
|
||||
async def grep(self, agent_state: AgentState, pattern: str) -> str:
|
||||
"""Stub for grep tool."""
|
||||
raise NotImplementedError
|
||||
def _validate_regex_pattern(self, pattern: str) -> None:
|
||||
"""Validate regex pattern to prevent catastrophic backtracking."""
|
||||
if len(pattern) > self.MAX_REGEX_COMPLEXITY:
|
||||
raise ValueError(f"Pattern too complex: {len(pattern)} chars > {self.MAX_REGEX_COMPLEXITY} limit")
|
||||
|
||||
# TODO: Make this paginated?
|
||||
async def search_files(self, agent_state: AgentState, query: str) -> List[str]:
|
||||
"""Search for text within attached files and return passages with their source filenames."""
|
||||
passages = await self.agent_manager.list_source_passages_async(actor=self.actor, agent_id=agent_state.id, query_text=query)
|
||||
formatted_results = []
|
||||
for p in passages:
|
||||
if p.file_name:
|
||||
formatted_result = f"[{p.file_name}]:\n{p.text}"
|
||||
# Test compile the pattern to catch syntax errors early
|
||||
try:
|
||||
re.compile(pattern, re.IGNORECASE | re.MULTILINE)
|
||||
except re.error as e:
|
||||
raise ValueError(f"Invalid regex pattern: {e}")
|
||||
|
||||
def _get_context_lines(self, text: str, match_line_idx: int, total_lines: int) -> List[str]:
|
||||
"""Get context lines around a match using LineChunker."""
|
||||
start_idx = max(0, match_line_idx - self.MAX_CONTEXT_LINES)
|
||||
end_idx = min(total_lines, match_line_idx + self.MAX_CONTEXT_LINES + 1)
|
||||
|
||||
# Use LineChunker to get formatted lines with numbers
|
||||
chunker = LineChunker()
|
||||
context_lines = chunker.chunk_text(text, start=start_idx, end=end_idx, add_metadata=False)
|
||||
|
||||
# Add match indicator
|
||||
formatted_lines = []
|
||||
for line in context_lines:
|
||||
if line and ":" in line:
|
||||
line_num_str = line.split(":")[0].strip()
|
||||
try:
|
||||
line_num = int(line_num_str)
|
||||
prefix = ">" if line_num == match_line_idx + 1 else " "
|
||||
formatted_lines.append(f"{prefix} {line}")
|
||||
except ValueError:
|
||||
formatted_lines.append(f" {line}")
|
||||
else:
|
||||
formatted_result = p.text
|
||||
formatted_results.append(formatted_result)
|
||||
return formatted_results
|
||||
formatted_lines.append(f" {line}")
|
||||
|
||||
return formatted_lines
|
||||
|
||||
async def grep(self, agent_state: AgentState, pattern: str, include: Optional[str] = None) -> str:
|
||||
"""
|
||||
Search for pattern in all attached files and return matches with context.
|
||||
|
||||
Args:
|
||||
agent_state: Current agent state
|
||||
pattern: Regular expression pattern to search for
|
||||
include: Optional pattern to filter filenames to include in the search
|
||||
|
||||
Returns:
|
||||
Formatted string with search results, file names, line numbers, and context
|
||||
"""
|
||||
if not pattern or not pattern.strip():
|
||||
raise ValueError("Empty search pattern provided")
|
||||
|
||||
pattern = pattern.strip()
|
||||
self._validate_regex_pattern(pattern)
|
||||
|
||||
# Validate include pattern if provided
|
||||
include_regex = None
|
||||
if include and include.strip():
|
||||
include = include.strip()
|
||||
# Convert glob pattern to regex if it looks like a glob pattern
|
||||
if "*" in include and not any(c in include for c in ["^", "$", "(", ")", "[", "]", "{", "}", "\\", "+"]):
|
||||
# Simple glob to regex conversion
|
||||
include_pattern = include.replace(".", r"\.").replace("*", ".*").replace("?", ".")
|
||||
if not include_pattern.endswith("$"):
|
||||
include_pattern += "$"
|
||||
else:
|
||||
include_pattern = include
|
||||
|
||||
self._validate_regex_pattern(include_pattern)
|
||||
include_regex = re.compile(include_pattern, re.IGNORECASE)
|
||||
|
||||
# Get all attached files for this agent
|
||||
file_agents = await self.files_agents_manager.list_files_for_agent(agent_id=agent_state.id, actor=self.actor)
|
||||
|
||||
if not file_agents:
|
||||
return "No files are currently attached to search"
|
||||
|
||||
# Filter files by filename pattern if include is specified
|
||||
if include_regex:
|
||||
original_count = len(file_agents)
|
||||
file_agents = [fa for fa in file_agents if include_regex.search(fa.file_name)]
|
||||
if not file_agents:
|
||||
return f"No files match the filename pattern '{include}' (filtered {original_count} files)"
|
||||
|
||||
# Compile regex pattern with appropriate flags
|
||||
regex_flags = re.MULTILINE
|
||||
regex_flags |= re.IGNORECASE
|
||||
|
||||
pattern_regex = re.compile(pattern, regex_flags)
|
||||
|
||||
results = []
|
||||
total_matches = 0
|
||||
total_content_size = 0
|
||||
files_processed = 0
|
||||
files_skipped = 0
|
||||
|
||||
# Use asyncio timeout to prevent hanging
|
||||
async def _search_files():
|
||||
nonlocal results, total_matches, total_content_size, files_processed, files_skipped
|
||||
|
||||
for file_agent in file_agents:
|
||||
# Load file content
|
||||
file = await self.source_manager.get_file_by_id(file_id=file_agent.file_id, actor=self.actor, include_content=True)
|
||||
|
||||
if not file or not file.content:
|
||||
files_skipped += 1
|
||||
self.logger.warning(f"Grep: Skipping file {file_agent.file_name} - no content available")
|
||||
continue
|
||||
|
||||
# Check individual file size
|
||||
content_size = len(file.content.encode("utf-8"))
|
||||
if content_size > self.MAX_FILE_SIZE_BYTES:
|
||||
files_skipped += 1
|
||||
self.logger.warning(
|
||||
f"Grep: Skipping file {file.file_name} - too large ({content_size:,} bytes > {self.MAX_FILE_SIZE_BYTES:,} limit)"
|
||||
)
|
||||
results.append(f"[SKIPPED] {file.file_name}: File too large ({content_size:,} bytes)")
|
||||
continue
|
||||
|
||||
# Check total content size across all files
|
||||
total_content_size += content_size
|
||||
if total_content_size > self.MAX_TOTAL_CONTENT_SIZE:
|
||||
files_skipped += 1
|
||||
self.logger.warning(
|
||||
f"Grep: Skipping file {file.file_name} - total content size limit exceeded ({total_content_size:,} bytes > {self.MAX_TOTAL_CONTENT_SIZE:,} limit)"
|
||||
)
|
||||
results.append(f"[SKIPPED] {file.file_name}: Total content size limit exceeded")
|
||||
break
|
||||
|
||||
files_processed += 1
|
||||
file_matches = 0
|
||||
|
||||
# Use LineChunker to get all lines with proper formatting
|
||||
chunker = LineChunker()
|
||||
formatted_lines = chunker.chunk_text(file.content)
|
||||
|
||||
# Remove metadata header
|
||||
if formatted_lines and formatted_lines[0].startswith("[Viewing"):
|
||||
formatted_lines = formatted_lines[1:]
|
||||
|
||||
# Search for matches in formatted lines
|
||||
for formatted_line in formatted_lines:
|
||||
if total_matches >= self.MAX_TOTAL_MATCHES:
|
||||
results.append(f"[TRUNCATED] Maximum total matches ({self.MAX_TOTAL_MATCHES}) reached")
|
||||
return
|
||||
|
||||
if file_matches >= self.MAX_MATCHES_PER_FILE:
|
||||
results.append(f"[TRUNCATED] {file.file_name}: Maximum matches per file ({self.MAX_MATCHES_PER_FILE}) reached")
|
||||
break
|
||||
|
||||
# Extract line number and content from formatted line
|
||||
if ":" in formatted_line:
|
||||
try:
|
||||
line_parts = formatted_line.split(":", 1)
|
||||
line_num = int(line_parts[0].strip())
|
||||
line_content = line_parts[1].strip() if len(line_parts) > 1 else ""
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
if pattern_regex.search(line_content):
|
||||
# Get context around the match (convert back to 0-based indexing)
|
||||
context_lines = self._get_context_lines(file.content, line_num - 1, len(file.content.splitlines()))
|
||||
|
||||
# Format the match result
|
||||
match_header = f"\n=== {file.file_name}:{line_num} ==="
|
||||
match_content = "\n".join(context_lines)
|
||||
results.append(f"{match_header}\n{match_content}")
|
||||
|
||||
file_matches += 1
|
||||
total_matches += 1
|
||||
|
||||
# Break if global limits reached
|
||||
if total_matches >= self.MAX_TOTAL_MATCHES:
|
||||
break
|
||||
|
||||
# Execute with timeout
|
||||
await asyncio.wait_for(_search_files(), timeout=self.GREP_TIMEOUT_SECONDS)
|
||||
|
||||
# Format final results
|
||||
if not results or total_matches == 0:
|
||||
summary = f"No matches found for pattern: '{pattern}'"
|
||||
if include:
|
||||
summary += f" in files matching '{include}'"
|
||||
if files_skipped > 0:
|
||||
summary += f" (searched {files_processed} files, skipped {files_skipped})"
|
||||
return summary
|
||||
|
||||
# Add summary header
|
||||
summary_parts = [f"Found {total_matches} matches"]
|
||||
if files_processed > 0:
|
||||
summary_parts.append(f"in {files_processed} files")
|
||||
if files_skipped > 0:
|
||||
summary_parts.append(f"({files_skipped} files skipped)")
|
||||
|
||||
summary = " ".join(summary_parts) + f" for pattern: '{pattern}'"
|
||||
if include:
|
||||
summary += f" in files matching '{include}'"
|
||||
|
||||
# Combine all results
|
||||
formatted_results = [summary, "=" * len(summary)] + results
|
||||
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
async def search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
|
||||
"""
|
||||
Search for text within attached files using semantic search and return passages with their source filenames.
|
||||
|
||||
Args:
|
||||
agent_state: Current agent state
|
||||
query: Search query for semantic matching
|
||||
limit: Maximum number of results to return (default: 10)
|
||||
|
||||
Returns:
|
||||
Formatted string with search results in IDE/terminal style
|
||||
"""
|
||||
if not query or not query.strip():
|
||||
raise ValueError("Empty search query provided")
|
||||
|
||||
query = query.strip()
|
||||
|
||||
# Apply reasonable limit
|
||||
limit = min(limit, self.MAX_TOTAL_MATCHES)
|
||||
|
||||
self.logger.info(f"Semantic search started for agent {agent_state.id} with query '{query}' (limit: {limit})")
|
||||
|
||||
# Get semantic search results
|
||||
passages = await self.agent_manager.list_source_passages_async(actor=self.actor, agent_id=agent_state.id, query_text=query)
|
||||
|
||||
if not passages:
|
||||
return f"No semantic matches found for query: '{query}'"
|
||||
|
||||
# Limit results
|
||||
passages = passages[:limit]
|
||||
|
||||
# Group passages by file for better organization
|
||||
files_with_passages = {}
|
||||
for p in passages:
|
||||
file_name = p.file_name if p.file_name else "Unknown File"
|
||||
if file_name not in files_with_passages:
|
||||
files_with_passages[file_name] = []
|
||||
files_with_passages[file_name].append(p)
|
||||
|
||||
results = []
|
||||
total_passages = 0
|
||||
|
||||
for file_name, file_passages in files_with_passages.items():
|
||||
for passage in file_passages:
|
||||
total_passages += 1
|
||||
|
||||
# Format each passage with terminal-style header
|
||||
passage_header = f"\n=== {file_name} (passage #{total_passages}) ==="
|
||||
|
||||
# Format the passage text with some basic formatting
|
||||
passage_text = passage.text.strip()
|
||||
|
||||
# Format the passage text without line numbers
|
||||
lines = passage_text.splitlines()
|
||||
formatted_lines = []
|
||||
for line in lines[:20]: # Limit to first 20 lines per passage
|
||||
formatted_lines.append(f" {line}")
|
||||
|
||||
if len(lines) > 20:
|
||||
formatted_lines.append(f" ... [truncated {len(lines) - 20} more lines]")
|
||||
|
||||
passage_content = "\n".join(formatted_lines)
|
||||
results.append(f"{passage_header}\n{passage_content}")
|
||||
|
||||
# Create summary header
|
||||
file_count = len(files_with_passages)
|
||||
summary = f"Found {total_passages} semantic matches in {file_count} file{'s' if file_count != 1 else ''} for query: '{query}'"
|
||||
|
||||
# Combine all results
|
||||
formatted_results = [summary, "=" * len(summary)] + results
|
||||
|
||||
self.logger.info(f"Semantic search completed: {total_passages} matches across {file_count} files")
|
||||
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
535
poetry.lock
generated
535
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -73,7 +73,7 @@ llama-index = "^0.12.2"
|
||||
llama-index-embeddings-openai = "^0.3.1"
|
||||
e2b-code-interpreter = {version = "^1.0.3", optional = true}
|
||||
anthropic = "^0.49.0"
|
||||
letta_client = "^0.1.147"
|
||||
letta_client = "^0.1.148"
|
||||
openai = "^1.60.0"
|
||||
opentelemetry-api = "1.30.0"
|
||||
opentelemetry-sdk = "1.30.0"
|
||||
|
||||
@@ -407,18 +407,6 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
files[0]
|
||||
|
||||
# Check that file is opened initially
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
blocks = agent_state.memory.file_blocks
|
||||
print(f"Agent has {len(blocks)} file block(s)")
|
||||
if blocks:
|
||||
initial_content_length = len(blocks[0].value)
|
||||
print(f"Initial file content length: {initial_content_length} characters")
|
||||
print(f"First 100 chars of content: {blocks[0].value[:100]}...")
|
||||
assert initial_content_length > 10, f"Expected file content > 10 chars, got {initial_content_length}"
|
||||
print("✓ File appears to be initially loaded")
|
||||
|
||||
# Ask agent to use the search_files tool
|
||||
search_files_response = client.agents.messages.create(
|
||||
@@ -441,6 +429,56 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
|
||||
|
||||
sources_list = client.sources.list()
|
||||
assert len(sources_list) == 1
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Load files into the source
|
||||
file_path = "tests/data/long_test.txt"
|
||||
print(f"Uploading file: {file_path}")
|
||||
|
||||
# Upload the files
|
||||
with open(file_path, "rb") as f:
|
||||
job = client.sources.files.upload(source_id=source.id, file=f)
|
||||
|
||||
print(f"File upload job created with ID: {job.id}, initial status: {job.status}")
|
||||
|
||||
# Wait for the jobs to complete
|
||||
while job.status != "completed":
|
||||
print(f"Waiting for job {job.id} to complete... Current status: {job.status}")
|
||||
time.sleep(1)
|
||||
job = client.jobs.retrieve(job_id=job.id)
|
||||
|
||||
# Get uploaded files
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Ask agent to use the search_files tool
|
||||
search_files_response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=f"Use ONLY the grep tool to search for `Nunzia De Girolamo`.")],
|
||||
)
|
||||
print(f"Grep request sent, got {len(search_files_response.messages)} message(s) in response")
|
||||
print(search_files_response.messages)
|
||||
|
||||
# Check that archival_memory_search was called
|
||||
tool_calls = [msg for msg in search_files_response.messages if msg.message_type == "tool_call_message"]
|
||||
assert len(tool_calls) > 0, "No tool calls found"
|
||||
assert any(tc.tool_call.name == "grep" for tc in tool_calls), "search_files not called"
|
||||
|
||||
# Check it returned successfully
|
||||
tool_returns = [msg for msg in search_files_response.messages if msg.message_type == "tool_return_message"]
|
||||
assert len(tool_returns) > 0, "No tool returns found"
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
|
||||
@@ -500,11 +538,11 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta
|
||||
assert (
|
||||
block.value
|
||||
== """
|
||||
[Viewing lines 50 to 55 (out of 100 lines)]
|
||||
Line 50: Line 51
|
||||
Line 51: Line 52
|
||||
Line 52: Line 53
|
||||
Line 53: Line 54
|
||||
Line 54: Line 55
|
||||
[Viewing lines 50 to 54 (out of 100 lines)]
|
||||
50: Line 51
|
||||
51: Line 52
|
||||
52: Line 53
|
||||
53: Line 54
|
||||
54: Line 55
|
||||
""".strip()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user