feat: Add optional lines param to grep tool (#2914)
This commit is contained in:
@@ -32,16 +32,23 @@ async def close_file(agent_state: "AgentState", file_name: str) -> str:
|
||||
raise NotImplementedError("Tool not implemented. Please contact the Letta team.")
|
||||
|
||||
|
||||
async def grep(
    agent_state: "AgentState",
    pattern: str,
    include: Optional[str] = None,
    context_lines: Optional[int] = 3,
) -> str:
    """
    Grep tool to search files across data sources using a keyword or regex pattern.

    Args:
        pattern (str): Keyword or regex pattern to search within file contents.
        include (Optional[str]): Optional keyword or regex pattern to filter filenames to include in the search.
        context_lines (Optional[int]): Number of lines of context to show before and after each match.
                                       Equivalent to `-C` in grep. Defaults to 3.

    Returns:
        str: Matching lines with optional surrounding context or a summary output.
    """
    # Declaration-only stub: this body always raises.
    raise NotImplementedError("Tool not implemented. Please contact the Letta team.")
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
@@ -97,6 +97,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
stderr=[get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))],
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def open_file(self, agent_state: AgentState, file_name: str, view_range: Optional[Tuple[int, int]] = None) -> str:
|
||||
"""Stub for open_file tool."""
|
||||
start, end = None, None
|
||||
@@ -131,6 +132,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
)
|
||||
return f"Successfully opened file {file_name}, lines {start} to {end} are now visible in memory block <{file_name}>"
|
||||
|
||||
@trace_method
|
||||
async def close_file(self, agent_state: AgentState, file_name: str) -> str:
|
||||
"""Stub for close_file tool."""
|
||||
await self.files_agents_manager.update_file_agent_by_name(
|
||||
@@ -149,32 +151,52 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
except re.error as e:
|
||||
raise ValueError(f"Invalid regex pattern: {e}")
|
||||
|
||||
def _get_context_lines(self, text: str, file_metadata: FileMetadata, match_line_idx: int, total_lines: int) -> List[str]:
|
||||
"""Get context lines around a match using LineChunker."""
|
||||
start_idx = max(0, match_line_idx - self.MAX_CONTEXT_LINES)
|
||||
end_idx = min(total_lines, match_line_idx + self.MAX_CONTEXT_LINES + 1)
|
||||
def _get_context_lines(
|
||||
self,
|
||||
formatted_lines: List[str],
|
||||
match_line_num: int,
|
||||
context_lines: int,
|
||||
) -> List[str]:
|
||||
"""Get context lines around a match from already-chunked lines.
|
||||
|
||||
# Use LineChunker to get formatted lines with numbers
|
||||
chunker = LineChunker()
|
||||
context_lines = chunker.chunk_text(text, file_metadata=file_metadata, start=start_idx, end=end_idx, add_metadata=False)
|
||||
Args:
|
||||
formatted_lines: Already chunked lines from LineChunker (format: "line_num: content")
|
||||
match_line_num: The 1-based line number of the match
|
||||
context_lines: Number of context lines before and after
|
||||
"""
|
||||
if not formatted_lines or context_lines < 0:
|
||||
return []
|
||||
|
||||
# Add match indicator
|
||||
formatted_lines = []
|
||||
for line in context_lines:
|
||||
# Find the index of the matching line in the formatted_lines list
|
||||
match_formatted_idx = None
|
||||
for i, line in enumerate(formatted_lines):
|
||||
if line and ":" in line:
|
||||
line_num_str = line.split(":")[0].strip()
|
||||
try:
|
||||
line_num = int(line_num_str)
|
||||
prefix = ">" if line_num == match_line_idx + 1 else " "
|
||||
formatted_lines.append(f"{prefix} {line}")
|
||||
line_num = int(line.split(":", 1)[0].strip())
|
||||
if line_num == match_line_num:
|
||||
match_formatted_idx = i
|
||||
break
|
||||
except ValueError:
|
||||
formatted_lines.append(f" {line}")
|
||||
else:
|
||||
formatted_lines.append(f" {line}")
|
||||
continue
|
||||
|
||||
return formatted_lines
|
||||
if match_formatted_idx is None:
|
||||
return []
|
||||
|
||||
async def grep(self, agent_state: AgentState, pattern: str, include: Optional[str] = None) -> str:
|
||||
# Calculate context range with bounds checking
|
||||
start_idx = max(0, match_formatted_idx - context_lines)
|
||||
end_idx = min(len(formatted_lines), match_formatted_idx + context_lines + 1)
|
||||
|
||||
# Extract context lines and add match indicator
|
||||
context_lines_with_indicator = []
|
||||
for i in range(start_idx, end_idx):
|
||||
line = formatted_lines[i]
|
||||
prefix = ">" if i == match_formatted_idx else " "
|
||||
context_lines_with_indicator.append(f"{prefix} {line}")
|
||||
|
||||
return context_lines_with_indicator
|
||||
|
||||
@trace_method
|
||||
async def grep(self, agent_state: AgentState, pattern: str, include: Optional[str] = None, context_lines: Optional[int] = 3) -> str:
|
||||
"""
|
||||
Search for pattern in all attached files and return matches with context.
|
||||
|
||||
@@ -182,6 +204,8 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
agent_state: Current agent state
|
||||
pattern: Regular expression pattern to search for
|
||||
include: Optional pattern to filter filenames to include in the search
|
||||
context_lines (Optional[int]): Number of lines of context to show before and after each match.
|
||||
Equivalent to `-C` in grep. Defaults to 3.
|
||||
|
||||
Returns:
|
||||
Formatted string with search results, file names, line numbers, and context
|
||||
@@ -277,6 +301,21 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
if formatted_lines and formatted_lines[0].startswith("[Viewing"):
|
||||
formatted_lines = formatted_lines[1:]
|
||||
|
||||
# Convert 0-based line numbers to 1-based for grep compatibility
|
||||
corrected_lines = []
|
||||
for line in formatted_lines:
|
||||
if line and ":" in line:
|
||||
try:
|
||||
line_parts = line.split(":", 1)
|
||||
line_num = int(line_parts[0].strip())
|
||||
line_content = line_parts[1] if len(line_parts) > 1 else ""
|
||||
corrected_lines.append(f"{line_num + 1}:{line_content}")
|
||||
except (ValueError, IndexError):
|
||||
corrected_lines.append(line)
|
||||
else:
|
||||
corrected_lines.append(line)
|
||||
formatted_lines = corrected_lines
|
||||
|
||||
# Search for matches in formatted lines
|
||||
for formatted_line in formatted_lines:
|
||||
if total_matches >= self.MAX_TOTAL_MATCHES:
|
||||
@@ -297,12 +336,11 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
continue
|
||||
|
||||
if pattern_regex.search(line_content):
|
||||
# Get context around the match (convert back to 0-based indexing)
|
||||
context_lines = self._get_context_lines(file.content, file, line_num - 1, len(file.content.splitlines()))
|
||||
context = self._get_context_lines(formatted_lines, match_line_num=line_num, context_lines=context_lines or 0)
|
||||
|
||||
# Format the match result
|
||||
match_header = f"\n=== {file.file_name}:{line_num} ==="
|
||||
match_content = "\n".join(context_lines)
|
||||
match_content = "\n".join(context)
|
||||
results.append(f"{match_header}\n{match_content}")
|
||||
|
||||
file_matches += 1
|
||||
@@ -340,6 +378,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
@trace_method
|
||||
async def search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
|
||||
"""
|
||||
Search for text within attached files using semantic search and return passages with their source filenames.
|
||||
|
||||
2431
tests/data/list_tools.json
Normal file
2431
tests/data/list_tools.json
Normal file
File diff suppressed because one or more lines are too long
@@ -438,7 +438,7 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -477,7 +477,7 @@ def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentSta
|
||||
print(f"Grep request sent, got {len(search_files_response.messages)} message(s) in response")
|
||||
print(search_files_response.messages)
|
||||
|
||||
# Check that archival_memory_search was called
|
||||
# Check that grep was called
|
||||
tool_calls = [msg for msg in search_files_response.messages if msg.message_type == "tool_call_message"]
|
||||
assert len(tool_calls) > 0, "No tool calls found"
|
||||
assert any(tc.tool_call.name == "grep" for tc in tool_calls), "search_files not called"
|
||||
@@ -488,6 +488,64 @@ def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentSta
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state: AgentState):
    """Check that grep returns the match with 3 lines of context on each side.

    Uploads tests/data/list_tools.json to a fresh source attached to the agent,
    asks the agent to grep for a tool id, and asserts that the tool return
    reports one match on line 510 together with lines 507-513.
    """
    # Create a new source
    source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")

    sources_list = client.sources.list()
    assert len(sources_list) == 1

    # Attach source to agent
    client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)

    # Load files into the source
    file_path = "tests/data/list_tools.json"
    print(f"Uploading file: {file_path}")

    # Upload the files
    with open(file_path, "rb") as f:
        job = client.sources.files.upload(source_id=source.id, file=f)

    print(f"File upload job created with ID: {job.id}, initial status: {job.status}")

    # Wait for the job to complete, but give up after a deadline so a job that
    # never reaches "completed" fails the test instead of hanging the run.
    # NOTE(review): 300s is an assumed generous upper bound - tune to real upload times.
    deadline = time.monotonic() + 300
    while job.status != "completed":
        if time.monotonic() > deadline:
            raise TimeoutError(f"Job {job.id} did not complete in time; last status: {job.status}")
        print(f"Waiting for job {job.id} to complete... Current status: {job.status}")
        time.sleep(1)
        job = client.jobs.retrieve(job_id=job.id)

    # Get uploaded files
    files = client.sources.files.list(source_id=source.id, limit=1)
    assert len(files) == 1
    assert files[0].source_id == source.id

    # Ask agent to use the grep tool
    search_files_response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=[MessageCreate(role="user", content="Use ONLY the grep tool to search for `tool-f5b80b08-5a45-4a0a-b2cd-dd8a0177b7ef`.")],
    )
    print(f"Grep request sent, got {len(search_files_response.messages)} message(s) in response")
    print(search_files_response.messages)

    tool_return_message = next((m for m in search_files_response.messages if m.message_type == "tool_return_message"), None)
    assert tool_return_message is not None, "No ToolReturnMessage found in messages"

    # Basic structural integrity checks
    assert tool_return_message.name == "grep"
    assert tool_return_message.status == "success"
    assert "Found 1 matches" in tool_return_message.tool_return
    assert "tool-f5b80b08-5a45-4a0a-b2cd-dd8a0177b7ef" in tool_return_message.tool_return

    # Context line integrity (3 lines before and after)
    assert "507:" in tool_return_message.tool_return
    assert "508:" in tool_return_message.tool_return
    assert "509:" in tool_return_message.tool_return
    assert "> 510:" in tool_return_message.tool_return  # Match line with > prefix
    assert "511:" in tool_return_message.tool_return
    assert "512:" in tool_return_message.tool_return
    assert "513:" in tool_return_message.tool_return
|
||||
|
||||
|
||||
def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
Reference in New Issue
Block a user