feat: add optional embedding_config parameter to file upload endpoint (#2901)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from starlette import status
|
||||
import letta.constants as constants
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.passage import Passage
|
||||
@@ -189,7 +190,7 @@ async def upload_file_to_source(
|
||||
raw_ct = file.content_type or ""
|
||||
media_type = raw_ct.split(";", 1)[0].strip().lower()
|
||||
|
||||
# If client didn’t supply a Content-Type or it’s not one of the allowed types,
|
||||
# If client didn't supply a Content-Type or it's not one of the allowed types,
|
||||
# attempt to infer from filename extension.
|
||||
if media_type not in allowed_media_types and file.filename:
|
||||
guessed, _ = mimetypes.guess_type(file.filename)
|
||||
@@ -216,6 +217,7 @@ async def upload_file_to_source(
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
if source is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
|
||||
|
||||
content = await file.read()
|
||||
|
||||
# sanitize filename
|
||||
@@ -249,7 +251,7 @@ async def upload_file_to_source(
|
||||
# Use cloud processing for all files (simple files always, complex files with Mistral key)
|
||||
logger.info("Running experimental cloud based file processing...")
|
||||
safe_create_task(
|
||||
load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor),
|
||||
load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor, source.embedding_config),
|
||||
logger=logger,
|
||||
label="file_processor.process",
|
||||
)
|
||||
@@ -347,10 +349,17 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac
|
||||
|
||||
|
||||
async def load_file_to_source_cloud(
|
||||
server: SyncServer, agent_states: List[AgentState], content: bytes, file: UploadFile, job: Job, source_id: str, actor: User
|
||||
server: SyncServer,
|
||||
agent_states: List[AgentState],
|
||||
content: bytes,
|
||||
file: UploadFile,
|
||||
job: Job,
|
||||
source_id: str,
|
||||
actor: User,
|
||||
embedding_config: EmbeddingConfig,
|
||||
):
|
||||
file_processor = MistralFileParser()
|
||||
text_chunker = LlamaIndexChunker()
|
||||
embedder = OpenAIEmbedder()
|
||||
text_chunker = LlamaIndexChunker(chunk_size=embedding_config.embedding_chunk_size)
|
||||
embedder = OpenAIEmbedder(embedding_config=embedding_config)
|
||||
file_processor = FileProcessor(file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor)
|
||||
await file_processor.process(server=server, agent_states=agent_states, source_id=source_id, content=content, file=file, job=job)
|
||||
|
||||
@@ -239,14 +239,16 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
|
||||
async def list_tools_async(
|
||||
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, upsert_base_tools: bool = True
|
||||
) -> List[PydanticTool]:
|
||||
"""List all tools with optional pagination."""
|
||||
tools = await self._list_tools_async(actor=actor, after=after, limit=limit)
|
||||
|
||||
# Check if all base tools are present if we requested all the tools w/o cursor
|
||||
# TODO: This is a temporary hack to resolve this issue
|
||||
# TODO: This requires a deeper rethink about how we keep all our internal tools up-to-date
|
||||
if not after:
|
||||
if not after and upsert_base_tools:
|
||||
existing_tool_names = {tool.name for tool in tools}
|
||||
missing_base_tools = LETTA_TOOL_SET - existing_tool_names
|
||||
|
||||
|
||||
@@ -30,9 +30,12 @@ from letta_client.types import (
|
||||
)
|
||||
|
||||
from letta.llm_api.openai_client import is_openai_reasoning_model
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ------------------------------
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
@@ -443,7 +446,10 @@ def agent_state(client: Letta) -> AgentState:
|
||||
)
|
||||
yield agent_state_instance
|
||||
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
try:
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -462,7 +468,10 @@ def agent_state_no_tools(client: Letta) -> AgentState:
|
||||
)
|
||||
yield agent_state_instance
|
||||
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
try:
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -101,44 +100,15 @@ def _run_server():
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url():
|
||||
"""
|
||||
Starts the Letta HTTP server in a separate process using the 'uvicorn' CLI,
|
||||
so its event loop and DB pool stay completely isolated from pytest-asyncio.
|
||||
"""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://127.0.0.1:8283")
|
||||
"""Ensures a server is running and returns its base URL."""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
# Only spawn our own server if the user hasn't overridden LETTA_SERVER_URL
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
# Build the command to launch uvicorn on your FastAPI app
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"uvicorn",
|
||||
"letta.server.rest_api.app:app",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"8283",
|
||||
]
|
||||
# If you need TLS or reload settings from start_server(), you can add
|
||||
# "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well.
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(url) # Allow server startup time
|
||||
|
||||
server_proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
# wait until the HTTP port is accepting connections
|
||||
wait_for_server(url)
|
||||
|
||||
yield url
|
||||
|
||||
# Teardown: kill the subprocess if we started it
|
||||
server_proc.terminate()
|
||||
server_proc.wait(timeout=10)
|
||||
else:
|
||||
yield url
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -521,7 +491,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor, server_url):
|
||||
|
||||
assert voice_agent.enable_sleeptime == True
|
||||
main_agent_tools = [tool.name for tool in voice_agent.tools]
|
||||
assert len(main_agent_tools) == 2
|
||||
assert len(main_agent_tools) == 4
|
||||
assert "send_message" in main_agent_tools
|
||||
assert "search_memory" in main_agent_tools
|
||||
assert "core_memory_append" not in main_agent_tools
|
||||
|
||||
@@ -107,8 +107,8 @@ def agent_state(client):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_mcp_server(client, agent_state):
|
||||
mcp_server_name = "github_composio"
|
||||
server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8/sse?useComposioHelperActions=true"
|
||||
mcp_server_name = "deepwiki"
|
||||
server_url = "https://mcp.deepwiki.com/sse"
|
||||
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
||||
client.tools.add_mcp_server(request=sse_mcp_config)
|
||||
|
||||
@@ -120,12 +120,15 @@ async def test_sse_mcp_server(client, agent_state):
|
||||
tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
|
||||
assert len(tools) > 0
|
||||
assert isinstance(tools[0], McpTool)
|
||||
star_mcp_tool = next((t for t in tools if t.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"), None)
|
||||
|
||||
# Check that one of the tools are executable
|
||||
letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=star_mcp_tool.name)
|
||||
# Test with the ask_question tool which is one of the available deepwiki tools
|
||||
ask_question_tool = next((t for t in tools if t.name == "ask_question"), None)
|
||||
assert ask_question_tool is not None, f"ask_question tool not found. Available tools: {[t.name for t in tools]}"
|
||||
|
||||
tool_args = {"owner": "letta-ai", "repo": "letta"}
|
||||
# Check that the tool is executable
|
||||
letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=ask_question_tool.name)
|
||||
|
||||
tool_args = {"repoName": "facebook/react", "question": "What is React?"}
|
||||
|
||||
# Add to agent, have agent invoke tool
|
||||
client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id)
|
||||
@@ -141,18 +144,15 @@ async def test_sse_mcp_server(client, agent_state):
|
||||
seq = response.messages
|
||||
calls = [m for m in seq if isinstance(m, ToolCallMessage)]
|
||||
assert calls, "Expected a ToolCallMessage"
|
||||
assert calls[0].tool_call.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"
|
||||
assert calls[0].tool_call.name == "ask_question"
|
||||
|
||||
returns = [m for m in seq if isinstance(m, ToolReturnMessage)]
|
||||
assert returns, "Expected a ToolReturnMessage"
|
||||
tr = returns[0]
|
||||
# status field
|
||||
assert tr.status == "success", f"Bad status: {tr.status}"
|
||||
# parse JSON payload
|
||||
full_payload = json.loads(tr.tool_return)
|
||||
|
||||
assert full_payload.get("successful", False), f"Tool returned failure payload: {full_payload}"
|
||||
assert full_payload["data"]["details"] == "Action executed successfully", f"Unexpected details: {full_payload}"
|
||||
# Check that we got some content back
|
||||
assert len(tr.tool_return.strip()) > 0, f"Expected non-empty tool return, got: {tr.tool_return}"
|
||||
|
||||
|
||||
def test_stdio_mcp_server(client, agent_state):
|
||||
|
||||
@@ -2912,7 +2912,7 @@ def test_get_tool_with_actor(server: SyncServer, print_tool, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(server: SyncServer, print_tool, default_user, event_loop):
|
||||
# List tools (should include the one created by the fixture)
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user)
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, upsert_base_tools=False)
|
||||
|
||||
# Assertions to check that the created tool is listed
|
||||
assert len(tools) == 1
|
||||
@@ -3041,7 +3041,7 @@ async def test_delete_tool_by_id(server: SyncServer, print_tool, default_user, e
|
||||
# Delete the print_tool using the manager method
|
||||
server.tool_manager.delete_tool_by_id(print_tool.id, actor=default_user)
|
||||
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user)
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, upsert_base_tools=False)
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user