feat: add optional embedding_config parameter to file upload endpoint (#2901)

Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Shangyin Tan
2025-06-18 18:10:48 -07:00
committed by GitHub
parent 40629285fc
commit c6b41fe379
6 changed files with 51 additions and 61 deletions

View File

@@ -11,6 +11,7 @@ from starlette import status
import letta.constants as constants
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.passage import Passage
@@ -189,7 +190,7 @@ async def upload_file_to_source(
raw_ct = file.content_type or ""
media_type = raw_ct.split(";", 1)[0].strip().lower()
# If client didn’t supply a Content-Type or it’s not one of the allowed types,
# If client didn't supply a Content-Type or it's not one of the allowed types,
# attempt to infer from filename extension.
if media_type not in allowed_media_types and file.filename:
guessed, _ = mimetypes.guess_type(file.filename)
@@ -216,6 +217,7 @@ async def upload_file_to_source(
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
if source is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
content = await file.read()
# sanitize filename
@@ -249,7 +251,7 @@ async def upload_file_to_source(
# Use cloud processing for all files (simple files always, complex files with Mistral key)
logger.info("Running experimental cloud based file processing...")
safe_create_task(
load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor),
load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor, source.embedding_config),
logger=logger,
label="file_processor.process",
)
@@ -347,10 +349,17 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac
async def load_file_to_source_cloud(
server: SyncServer, agent_states: List[AgentState], content: bytes, file: UploadFile, job: Job, source_id: str, actor: User
server: SyncServer,
agent_states: List[AgentState],
content: bytes,
file: UploadFile,
job: Job,
source_id: str,
actor: User,
embedding_config: EmbeddingConfig,
):
file_processor = MistralFileParser()
text_chunker = LlamaIndexChunker()
embedder = OpenAIEmbedder()
text_chunker = LlamaIndexChunker(chunk_size=embedding_config.embedding_chunk_size)
embedder = OpenAIEmbedder(embedding_config=embedding_config)
file_processor = FileProcessor(file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor)
await file_processor.process(server=server, agent_states=agent_states, source_id=source_id, content=content, file=file, job=job)

View File

@@ -239,14 +239,16 @@ class ToolManager:
@enforce_types
@trace_method
async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
async def list_tools_async(
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, upsert_base_tools: bool = True
) -> List[PydanticTool]:
"""List all tools with optional pagination."""
tools = await self._list_tools_async(actor=actor, after=after, limit=limit)
# Check if all base tools are present if we requested all the tools w/o cursor
# TODO: This is a temporary hack to resolve this issue
# TODO: This requires a deeper rethink about how we keep all our internal tools up-to-date
if not after:
if not after and upsert_base_tools:
existing_tool_names = {tool.name for tool in tools}
missing_base_tools = LETTA_TOOL_SET - existing_tool_names

View File

@@ -30,9 +30,12 @@ from letta_client.types import (
)
from letta.llm_api.openai_client import is_openai_reasoning_model
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
logger = get_logger(__name__)
# ------------------------------
# Helper Functions and Constants
# ------------------------------
@@ -443,7 +446,10 @@ def agent_state(client: Letta) -> AgentState:
)
yield agent_state_instance
client.agents.delete(agent_state_instance.id)
try:
client.agents.delete(agent_state_instance.id)
except Exception as e:
logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
@pytest.fixture(scope="function")
@@ -462,7 +468,10 @@ def agent_state_no_tools(client: Letta) -> AgentState:
)
yield agent_state_instance
client.agents.delete(agent_state_instance.id)
try:
client.agents.delete(agent_state_instance.id)
except Exception as e:
logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
# ------------------------------

View File

@@ -1,6 +1,5 @@
import os
import subprocess
import sys
import threading
from unittest.mock import MagicMock
import pytest
@@ -101,44 +100,15 @@ def _run_server():
@pytest.fixture(scope="module")
def server_url():
"""
Starts the Letta HTTP server in a separate process using the 'uvicorn' CLI,
so its event loop and DB pool stay completely isolated from pytest-asyncio.
"""
url = os.getenv("LETTA_SERVER_URL", "http://127.0.0.1:8283")
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
# Only spawn our own server if the user hasn't overridden LETTA_SERVER_URL
if not os.getenv("LETTA_SERVER_URL"):
# Build the command to launch uvicorn on your FastAPI app
cmd = [
sys.executable,
"-m",
"uvicorn",
"letta.server.rest_api.app:app",
"--host",
"127.0.0.1",
"--port",
"8283",
]
# If you need TLS or reload settings from start_server(), you can add
# "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well.
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
wait_for_server(url) # Allow server startup time
server_proc = subprocess.Popen(
cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
# wait until the HTTP port is accepting connections
wait_for_server(url)
yield url
# Teardown: kill the subprocess if we started it
server_proc.terminate()
server_proc.wait(timeout=10)
else:
yield url
return url
@pytest.fixture(scope="module")
@@ -521,7 +491,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor, server_url):
assert voice_agent.enable_sleeptime == True
main_agent_tools = [tool.name for tool in voice_agent.tools]
assert len(main_agent_tools) == 2
assert len(main_agent_tools) == 4
assert "send_message" in main_agent_tools
assert "search_memory" in main_agent_tools
assert "core_memory_append" not in main_agent_tools

View File

@@ -107,8 +107,8 @@ def agent_state(client):
@pytest.mark.asyncio
async def test_sse_mcp_server(client, agent_state):
mcp_server_name = "github_composio"
server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8/sse?useComposioHelperActions=true"
mcp_server_name = "deepwiki"
server_url = "https://mcp.deepwiki.com/sse"
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
client.tools.add_mcp_server(request=sse_mcp_config)
@@ -120,12 +120,15 @@ async def test_sse_mcp_server(client, agent_state):
tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
assert len(tools) > 0
assert isinstance(tools[0], McpTool)
star_mcp_tool = next((t for t in tools if t.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"), None)
# Check that one of the tools are executable
letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=star_mcp_tool.name)
# Test with the ask_question tool which is one of the available deepwiki tools
ask_question_tool = next((t for t in tools if t.name == "ask_question"), None)
assert ask_question_tool is not None, f"ask_question tool not found. Available tools: {[t.name for t in tools]}"
tool_args = {"owner": "letta-ai", "repo": "letta"}
# Check that the tool is executable
letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=ask_question_tool.name)
tool_args = {"repoName": "facebook/react", "question": "What is React?"}
# Add to agent, have agent invoke tool
client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id)
@@ -141,18 +144,15 @@ async def test_sse_mcp_server(client, agent_state):
seq = response.messages
calls = [m for m in seq if isinstance(m, ToolCallMessage)]
assert calls, "Expected a ToolCallMessage"
assert calls[0].tool_call.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"
assert calls[0].tool_call.name == "ask_question"
returns = [m for m in seq if isinstance(m, ToolReturnMessage)]
assert returns, "Expected a ToolReturnMessage"
tr = returns[0]
# status field
assert tr.status == "success", f"Bad status: {tr.status}"
# parse JSON payload
full_payload = json.loads(tr.tool_return)
assert full_payload.get("successful", False), f"Tool returned failure payload: {full_payload}"
assert full_payload["data"]["details"] == "Action executed successfully", f"Unexpected details: {full_payload}"
# Check that we got some content back
assert len(tr.tool_return.strip()) > 0, f"Expected non-empty tool return, got: {tr.tool_return}"
def test_stdio_mcp_server(client, agent_state):

View File

@@ -2912,7 +2912,7 @@ def test_get_tool_with_actor(server: SyncServer, print_tool, default_user):
@pytest.mark.asyncio
async def test_list_tools(server: SyncServer, print_tool, default_user, event_loop):
# List tools (should include the one created by the fixture)
tools = await server.tool_manager.list_tools_async(actor=default_user)
tools = await server.tool_manager.list_tools_async(actor=default_user, upsert_base_tools=False)
# Assertions to check that the created tool is listed
assert len(tools) == 1
@@ -3041,7 +3041,7 @@ async def test_delete_tool_by_id(server: SyncServer, print_tool, default_user, e
# Delete the print_tool using the manager method
server.tool_manager.delete_tool_by_id(print_tool.id, actor=default_user)
tools = await server.tool_manager.list_tools_async(actor=default_user)
tools = await server.tool_manager.list_tools_async(actor=default_user, upsert_base_tools=False)
assert len(tools) == 0