feat: Add close all files functionality (#3139)
This commit is contained in:
@@ -131,7 +131,7 @@ MEMORY_TOOLS_LINE_NUMBER_PREFIX_REGEX = re.compile(
|
||||
BUILTIN_TOOLS = ["run_code", "web_search"]
|
||||
|
||||
# Built in tools
|
||||
FILES_TOOLS = ["open_files", "grep_files", "search_files"]
|
||||
FILES_TOOLS = ["open_files", "grep_files", "semantic_search_files"]
|
||||
|
||||
FILE_MEMORY_EXISTS_MESSAGE = "The following files are currently accessible in memory:"
|
||||
FILE_MEMORY_EMPTY_MESSAGE = (
|
||||
|
||||
@@ -65,12 +65,12 @@ async def grep_files(
|
||||
raise NotImplementedError("Tool not implemented. Please contact the Letta team.")
|
||||
|
||||
|
||||
async def search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
|
||||
async def semantic_search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
|
||||
"""
|
||||
Get list of most relevant chunks from any file using embedding search.
|
||||
Get list of most relevant chunks from any file using vector/embedding search.
|
||||
|
||||
Use this when you want to:
|
||||
- Find related content that may not match exact keywords (e.g., conceptually similar sections)
|
||||
- Find related content that without using exact keywords (e.g., conceptually similar sections)
|
||||
- Look up high-level descriptions, documentation, or config patterns
|
||||
- Perform fuzzy search when grep isn't sufficient
|
||||
|
||||
|
||||
@@ -366,6 +366,23 @@ async def detach_source(
|
||||
return agent_state
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/files/close-all", response_model=List[str], operation_id="close_all_open_files")
|
||||
async def close_all_open_files(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Closes all currently open files for a given agent.
|
||||
|
||||
This endpoint updates the file state for the agent so that no files are marked as open.
|
||||
Typically used to reset the working memory view for the agent.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
return server.file_agent_manager.close_all_other_files(agent_id=agent_id, keep_file_names=[], actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent")
|
||||
async def retrieve_agent(
|
||||
agent_id: str,
|
||||
|
||||
@@ -76,7 +76,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
function_map = {
|
||||
"open_files": self.open_files,
|
||||
"grep_files": self.grep_files,
|
||||
"search_files": self.search_files,
|
||||
"semantic_search_files": self.semantic_search_files,
|
||||
}
|
||||
|
||||
if function_name not in function_map:
|
||||
@@ -463,7 +463,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
@trace_method
|
||||
async def search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
|
||||
async def semantic_search_files(self, agent_state: AgentState, query: str, limit: int = 10) -> str:
|
||||
"""
|
||||
Search for text within attached files using semantic search and return passages with their source filenames.
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ def upload_file_and_wait(client: LettaSDKClient, source_id: str, file_path: str,
|
||||
@pytest.fixture
|
||||
def agent_state(client: LettaSDKClient):
|
||||
open_file_tool = client.tools.list(name="open_files")[0]
|
||||
search_files_tool = client.tools.list(name="search_files")[0]
|
||||
search_files_tool = client.tools.list(name="semantic_search_files")[0]
|
||||
grep_tool = client.tools.list(name="grep_files")[0]
|
||||
|
||||
agent_state = client.agents.create(
|
||||
@@ -400,11 +400,13 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Ask agent to use the search_files tool
|
||||
# Ask agent to use the semantic_search_files tool
|
||||
search_files_response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
MessageCreate(role="user", content=f"Use ONLY the search_files tool to search for details regarding the electoral history.")
|
||||
MessageCreate(
|
||||
role="user", content=f"Use ONLY the semantic_search_files tool to search for details regarding the electoral history."
|
||||
)
|
||||
],
|
||||
)
|
||||
print(f"Search file request sent, got {len(search_files_response.messages)} message(s) in response")
|
||||
@@ -413,7 +415,7 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state:
|
||||
# Check that archival_memory_search was called
|
||||
tool_calls = [msg for msg in search_files_response.messages if msg.message_type == "tool_call_message"]
|
||||
assert len(tool_calls) > 0, "No tool calls found"
|
||||
assert any(tc.tool_call.name == "search_files" for tc in tool_calls), "search_files not called"
|
||||
assert any(tc.tool_call.name == "semantic_search_files" for tc in tool_calls), "semantic_search_files not called"
|
||||
|
||||
# Check it returned successfully
|
||||
tool_returns = [msg for msg in search_files_response.messages if msg.message_type == "tool_return_message"]
|
||||
@@ -444,7 +446,7 @@ def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: Ag
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Ask agent to use the search_files tool
|
||||
# Ask agent to use the semantic_search_files tool
|
||||
search_files_response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=f"Use ONLY the grep_files tool to search for `Nunzia De Girolamo`.")],
|
||||
@@ -455,7 +457,7 @@ def test_agent_uses_grep_correctly_basic(client: LettaSDKClient, agent_state: Ag
|
||||
# Check that grep_files was called
|
||||
tool_calls = [msg for msg in search_files_response.messages if msg.message_type == "tool_call_message"]
|
||||
assert len(tool_calls) > 0, "No tool calls found"
|
||||
assert any(tc.tool_call.name == "grep_files" for tc in tool_calls), "search_files not called"
|
||||
assert any(tc.tool_call.name == "grep_files" for tc in tool_calls), "semantic_search_files not called"
|
||||
|
||||
# Check it returned successfully
|
||||
tool_returns = [msg for msg in search_files_response.messages if msg.message_type == "tool_return_message"]
|
||||
@@ -486,7 +488,7 @@ def test_agent_uses_grep_correctly_advanced(client: LettaSDKClient, agent_state:
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Ask agent to use the search_files tool
|
||||
# Ask agent to use the semantic_search_files tool
|
||||
search_files_response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
|
||||
Reference in New Issue
Block a user