feat: Add close all files functionality (#3139)

This commit is contained in:
Matthew Zhou
2025-07-02 14:27:38 -07:00
committed by GitHub
parent 3e885a4ef7
commit 243d3d040b
5 changed files with 32 additions and 13 deletions

View File

@@ -131,7 +131,7 @@ MEMORY_TOOLS_LINE_NUMBER_PREFIX_REGEX = re.compile(
BUILTIN_TOOLS = ["run_code", "web_search"]
# Built in tools
FILES_TOOLS = ["open_files", "grep_files", "search_files"]
FILES_TOOLS = ["open_files", "grep_files", "semantic_search_files"]
FILE_MEMORY_EXISTS_MESSAGE = "The following files are currently accessible in memory:"
FILE_MEMORY_EMPTY_MESSAGE = (

View File

@@ -65,12 +65,12 @@ async def grep_files(
raise NotImplementedError("Tool not implemented. Please contact the Letta team.")
async def search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
async def semantic_search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
"""
Get list of most relevant chunks from any file using embedding search.
Get list of most relevant chunks from any file using vector/embedding search.
Use this when you want to:
- Find related content that may not match exact keywords (e.g., conceptually similar sections)
- Find related content that without using exact keywords (e.g., conceptually similar sections)
- Look up high-level descriptions, documentation, or config patterns
- Perform fuzzy search when grep isn't sufficient

View File

@@ -366,6 +366,23 @@ async def detach_source(
return agent_state
@router.patch("/{agent_id}/files/close-all", response_model=List[str], operation_id="close_all_open_files")
async def close_all_open_files(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Closes all currently open files for a given agent.
This endpoint updates the file state for the agent so that no files are marked as open.
Typically used to reset the working memory view for the agent.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.file_agent_manager.close_all_other_files(agent_id=agent_id, keep_file_names=[], actor=actor)
@router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent")
async def retrieve_agent(
agent_id: str,

View File

@@ -76,7 +76,7 @@ class LettaFileToolExecutor(ToolExecutor):
function_map = {
"open_files": self.open_files,
"grep_files": self.grep_files,
"search_files": self.search_files,
"semantic_search_files": self.semantic_search_files,
}
if function_name not in function_map:
@@ -463,7 +463,7 @@ class LettaFileToolExecutor(ToolExecutor):
return "\n".join(formatted_results)
@trace_method
async def search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
async def semantic_search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
"""
Search for text within attached files using semantic search and return passages with their source filenames.

View File

@@ -72,7 +72,7 @@ def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str,
@pytest.fixture
def agent_state(client: LettaSDKClient):
open_file_tool = client.tools.list(name="open_files")[0]
search_files_tool = client.tools.list(name="search_files")[0]
search_files_tool = client.tools.list(name="semantic_search_files")[0]
grep_tool = client.tools.list(name="grep_files")[0]
agent_state = client.agents.create(
@@ -400,11 +400,13 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
assert len(files) == 1
assert files[0].source_id == source.id
# Ask agent to use the search_files tool
# Ask agent to use the semantic_search_files tool
search_files_response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[
MessageCreate(role="user", content=f"Use ONLY the search_files tool to search for details regarding the electoral history.")
MessageCreate(
role="user", content=f"Use ONLY the semantic_search_files tool to search for details regarding the electoral history."
)
],
)
print(f"Search file request sent, got {len(search_files_response.messages)} message(s) in response")
@@ -413,7 +415,7 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
# Check that archival_memory_search was called
tool_calls = [msg for msg in search_files_response.messages if msg.message_type == "tool_call_message"]
assert len(tool_calls) > 0, "No tool calls found"
assert any(tc.tool_call.name == "search_files" for tc in tool_calls), "search_files not called"
assert any(tc.tool_call.name == "semantic_search_files" for tc in tool_calls), "semantic_search_files not called"
# Check it returned successfully
tool_returns = [msg for msg in search_files_response.messages if msg.message_type == "tool_return_message"]
@@ -444,7 +446,7 @@ def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: Ag
assert len(files) == 1
assert files[0].source_id == source.id
# Ask agent to use the search_files tool
# Ask agent to use the semantic_search_files tool
search_files_response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[MessageCreate(role="user", content=f"Use ONLY the grep_files tool to search for `Nunzia De Girolamo`.")],
@@ -455,7 +457,7 @@ def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: Ag
# Check that grep_files was called
tool_calls = [msg for msg in search_files_response.messages if msg.message_type == "tool_call_message"]
assert len(tool_calls) > 0, "No tool calls found"
assert any(tc.tool_call.name == "grep_files" for tc in tool_calls), "search_files not called"
assert any(tc.tool_call.name == "grep_files" for tc in tool_calls), "semantic_search_files not called"
# Check it returned successfully
tool_returns = [msg for msg in search_files_response.messages if msg.message_type == "tool_return_message"]
@@ -486,7 +488,7 @@ def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state:
assert len(files) == 1
assert files[0].source_id == source.id
# Ask agent to use the search_files tool
# Ask agent to use the semantic_search_files tool
search_files_response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[