From c6b41fe3790858bd91763d45bc54343233c57571 Mon Sep 17 00:00:00 2001 From: Shangyin Tan Date: Wed, 18 Jun 2025 18:10:48 -0700 Subject: [PATCH] feat: add optional embedding_config parameter to file upload endpoint (#2901) Co-authored-by: Matt Zhou --- letta/server/rest_api/routers/v1/sources.py | 19 ++++++--- letta/services/tool_manager.py | 6 ++- tests/integration_test_send_message.py | 13 +++++- tests/integration_test_voice_agent.py | 46 ++++----------------- tests/mcp/test_mcp.py | 24 +++++------ tests/test_managers.py | 4 +- 6 files changed, 51 insertions(+), 61 deletions(-) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 28179cbf..39427321 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -11,6 +11,7 @@ from starlette import status import letta.constants as constants from letta.log import get_logger from letta.schemas.agent import AgentState +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.passage import Passage @@ -189,7 +190,7 @@ async def upload_file_to_source( raw_ct = file.content_type or "" media_type = raw_ct.split(";", 1)[0].strip().lower() - # If client didn’t supply a Content-Type or it’s not one of the allowed types, + # If client didn't supply a Content-Type or it's not one of the allowed types, # attempt to infer from filename extension. if media_type not in allowed_media_types and file.filename: guessed, _ = mimetypes.guess_type(file.filename) @@ -216,6 +217,7 @@ async def upload_file_to_source( source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) if source is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.") + content = await file.read() # sanitize filename @@ -249,7 +251,7 @@ async def upload_file_to_source( # Use cloud processing for all files (simple files always, complex files with Mistral key) logger.info("Running experimental cloud based file processing...") safe_create_task( - load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor), + load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor, source.embedding_config), logger=logger, label="file_processor.process", ) @@ -347,10 +349,17 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac async def load_file_to_source_cloud( - server: SyncServer, agent_states: List[AgentState], content: bytes, file: UploadFile, job: Job, source_id: str, actor: User + server: SyncServer, + agent_states: List[AgentState], + content: bytes, + file: UploadFile, + job: Job, + source_id: str, + actor: User, + embedding_config: EmbeddingConfig, ): file_processor = MistralFileParser() - text_chunker = LlamaIndexChunker() - embedder = OpenAIEmbedder() + text_chunker = LlamaIndexChunker(chunk_size=embedding_config.embedding_chunk_size) + embedder = OpenAIEmbedder(embedding_config=embedding_config) file_processor = FileProcessor(file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor) await file_processor.process(server=server, agent_states=agent_states, source_id=source_id, content=content, file=file, job=job) diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 9e0d1a26..f73c7075 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -239,14 +239,16 @@ class ToolManager: @enforce_types @trace_method - async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: + async def list_tools_async( + self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, upsert_base_tools: bool = True + ) -> List[PydanticTool]: """List all tools with optional pagination.""" tools = await self._list_tools_async(actor=actor, after=after, limit=limit) # Check if all base tools are present if we requested all the tools w/o cursor # TODO: This is a temporary hack to resolve this issue # TODO: This requires a deeper rethink about how we keep all our internal tools up-to-date - if not after: + if not after and upsert_base_tools: existing_tool_names = {tool.name for tool in tools} missing_base_tools = LETTA_TOOL_SET - existing_tool_names diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index a469ca2d..3a848527 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -30,9 +30,12 @@ from letta_client.types import ( ) from letta.llm_api.openai_client import is_openai_reasoning_model +from letta.log import get_logger from letta.schemas.agent import AgentState from letta.schemas.llm_config import LLMConfig +logger = get_logger(__name__) + # ------------------------------ # Helper Functions and Constants # ------------------------------ @@ -443,7 +446,10 @@ def agent_state(client: Letta) -> AgentState: ) yield agent_state_instance - client.agents.delete(agent_state_instance.id) + try: + client.agents.delete(agent_state_instance.id) + except Exception as e: + logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}") @pytest.fixture(scope="function") @@ -462,7 +468,10 @@ def agent_state_no_tools(client: Letta) -> AgentState: ) yield agent_state_instance - client.agents.delete(agent_state_instance.id) + try: + client.agents.delete(agent_state_instance.id) + except Exception as e: + logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}") # ------------------------------ diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index 1c61dcec..2da7d19f 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -1,6 +1,5 @@ import os -import subprocess -import sys +import threading from unittest.mock import MagicMock import pytest @@ -101,44 +100,15 @@ def _run_server(): @pytest.fixture(scope="module") def server_url(): - """ - Starts the Letta HTTP server in a separate process using the 'uvicorn' CLI, - so its event loop and DB pool stay completely isolated from pytest-asyncio. - """ - url = os.getenv("LETTA_SERVER_URL", "http://127.0.0.1:8283") + """Ensures a server is running and returns its base URL.""" + url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - # Only spawn our own server if the user hasn't overridden LETTA_SERVER_URL if not os.getenv("LETTA_SERVER_URL"): - # Build the command to launch uvicorn on your FastAPI app - cmd = [ - sys.executable, - "-m", - "uvicorn", - "letta.server.rest_api.app:app", - "--host", - "127.0.0.1", - "--port", - "8283", - ] - # If you need TLS or reload settings from start_server(), you can add - # "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well. + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + wait_for_server(url) # Allow server startup time - server_proc = subprocess.Popen( - cmd, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - - # wait until the HTTP port is accepting connections - wait_for_server(url) - - yield url - - # Teardown: kill the subprocess if we started it - server_proc.terminate() - server_proc.wait(timeout=10) - else: - yield url + return url @pytest.fixture(scope="module") @@ -521,7 +491,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor, server_url): assert voice_agent.enable_sleeptime == True main_agent_tools = [tool.name for tool in voice_agent.tools] - assert len(main_agent_tools) == 2 + assert len(main_agent_tools) == 4 assert "send_message" in main_agent_tools assert "search_memory" in main_agent_tools assert "core_memory_append" not in main_agent_tools diff --git a/tests/mcp/test_mcp.py b/tests/mcp/test_mcp.py index 41ee8e8c..d17b9f38 100644 --- a/tests/mcp/test_mcp.py +++ b/tests/mcp/test_mcp.py @@ -107,8 +107,8 @@ def agent_state(client): @pytest.mark.asyncio async def test_sse_mcp_server(client, agent_state): - mcp_server_name = "github_composio" - server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8/sse?useComposioHelperActions=true" + mcp_server_name = "deepwiki" + server_url = "https://mcp.deepwiki.com/sse" sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url) client.tools.add_mcp_server(request=sse_mcp_config) @@ -120,12 +120,15 @@ async def test_sse_mcp_server(client, agent_state): tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name) assert len(tools) > 0 assert isinstance(tools[0], McpTool) - star_mcp_tool = next((t for t in tools if t.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"), None) - # Check that one of the tools are executable - letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=star_mcp_tool.name) + # Test with the ask_question tool which is one of the available deepwiki tools + ask_question_tool = next((t for t in tools if t.name == "ask_question"), None) + assert ask_question_tool is not None, f"ask_question tool not found. Available tools: {[t.name for t in tools]}" - tool_args = {"owner": "letta-ai", "repo": "letta"} + # Check that the tool is executable + letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=ask_question_tool.name) + + tool_args = {"repoName": "facebook/react", "question": "What is React?"} # Add to agent, have agent invoke tool client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id) @@ -141,18 +144,15 @@ async def test_sse_mcp_server(client, agent_state): seq = response.messages calls = [m for m in seq if isinstance(m, ToolCallMessage)] assert calls, "Expected a ToolCallMessage" - assert calls[0].tool_call.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER" + assert calls[0].tool_call.name == "ask_question" returns = [m for m in seq if isinstance(m, ToolReturnMessage)] assert returns, "Expected a ToolReturnMessage" tr = returns[0] # status field assert tr.status == "success", f"Bad status: {tr.status}" - # parse JSON payload - full_payload = json.loads(tr.tool_return) - - assert full_payload.get("successful", False), f"Tool returned failure payload: {full_payload}" - assert full_payload["data"]["details"] == "Action executed successfully", f"Unexpected details: {full_payload}" + # Check that we got some content back + assert len(tr.tool_return.strip()) > 0, f"Expected non-empty tool return, got: {tr.tool_return}" def test_stdio_mcp_server(client, agent_state): diff --git a/tests/test_managers.py b/tests/test_managers.py index a345b0e0..cb972432 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2912,7 +2912,7 @@ def test_get_tool_with_actor(server: SyncServer, print_tool, default_user): @pytest.mark.asyncio async def test_list_tools(server: SyncServer, print_tool, default_user, event_loop): # List tools (should include the one created by the fixture) - tools = await server.tool_manager.list_tools_async(actor=default_user) + tools = await server.tool_manager.list_tools_async(actor=default_user, upsert_base_tools=False) # Assertions to check that the created tool is listed assert len(tools) == 1 @@ -3041,7 +3041,7 @@ async def test_delete_tool_by_id(server: SyncServer, print_tool, default_user, e # Delete the print_tool using the manager method server.tool_manager.delete_tool_by_id(print_tool.id, actor=default_user) - tools = await server.tool_manager.list_tools_async(actor=default_user) + tools = await server.tool_manager.list_tools_async(actor=default_user, upsert_base_tools=False) assert len(tools) == 0