feat: Add optional lines param to grep tool (#2914)
This commit is contained in:
@@ -438,7 +438,7 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -477,7 +477,7 @@ def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentSta
|
||||
print(f"Grep request sent, got {len(search_files_response.messages)} message(s) in response")
|
||||
print(search_files_response.messages)
|
||||
|
||||
# Check that archival_memory_search was called
|
||||
# Check that grep was called
|
||||
tool_calls = [msg for msg in search_files_response.messages if msg.message_type == "tool_call_message"]
|
||||
assert len(tool_calls) > 0, "No tool calls found"
|
||||
assert any(tc.tool_call.name == "grep" for tc in tool_calls), "search_files not called"
|
||||
@@ -488,6 +488,64 @@ def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentSta
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
sources_list = client.sources.list()
|
||||
assert len(sources_list) == 1
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Load files into the source
|
||||
file_path = "tests/data/list_tools.json"
|
||||
print(f"Uploading file: {file_path}")
|
||||
|
||||
# Upload the files
|
||||
with open(file_path, "rb") as f:
|
||||
job = client.sources.files.upload(source_id=source.id, file=f)
|
||||
|
||||
print(f"File upload job created with ID: {job.id}, initial status: {job.status}")
|
||||
|
||||
# Wait for the jobs to complete
|
||||
while job.status != "completed":
|
||||
print(f"Waiting for job {job.id} to complete... Current status: {job.status}")
|
||||
time.sleep(1)
|
||||
job = client.jobs.retrieve(job_id=job.id)
|
||||
|
||||
# Get uploaded files
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Ask agent to use the search_files tool
|
||||
search_files_response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=f"Use ONLY the grep tool to search for `tool-f5b80b08-5a45-4a0a-b2cd-dd8a0177b7ef`.")],
|
||||
)
|
||||
print(f"Grep request sent, got {len(search_files_response.messages)} message(s) in response")
|
||||
print(search_files_response.messages)
|
||||
|
||||
tool_return_message = next((m for m in search_files_response.messages if m.message_type == "tool_return_message"), None)
|
||||
assert tool_return_message is not None, "No ToolReturnMessage found in messages"
|
||||
|
||||
# Basic structural integrity checks
|
||||
assert tool_return_message.name == "grep"
|
||||
assert tool_return_message.status == "success"
|
||||
assert "Found 1 matches" in tool_return_message.tool_return
|
||||
assert "tool-f5b80b08-5a45-4a0a-b2cd-dd8a0177b7ef" in tool_return_message.tool_return
|
||||
|
||||
# Context line integrity (3 lines before and after)
|
||||
assert "507:" in tool_return_message.tool_return
|
||||
assert "508:" in tool_return_message.tool_return
|
||||
assert "509:" in tool_return_message.tool_return
|
||||
assert "> 510:" in tool_return_message.tool_return # Match line with > prefix
|
||||
assert "511:" in tool_return_message.tool_return
|
||||
assert "512:" in tool_return_message.tool_return
|
||||
assert "513:" in tool_return_message.tool_return
|
||||
|
||||
|
||||
def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
Reference in New Issue
Block a user