Files
letta-server/tests/integration_test_pinecone_tool.py
Matthew Zhou 82fc01ed04 feat: Adjust import/export agent endpoints to accept new agent file schema (#3506)
Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Shubham Naik <shub@letta.com>
2025-08-12 11:18:56 -07:00

215 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import json
import os
import threading
import time
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage
from letta_client.core import RequestOptions
# Minimum interval (ms) between successive rebuilds of the accumulated
# reasoning message while consuming the stream; chunks arriving faster
# than this trigger a short sleep in the stream loop.
REASONING_THROTTLE_MS = 100
# User prompt sent to the imported knowledge-base agent during the test.
TEST_USER_MESSAGE = "What products or services does 11x AI sell?"
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provide the base URL for the Letta server under test.

    If LETTA_SERVER_URL is set, that URL is returned directly. Otherwise a
    server is started in a background daemon thread and this fixture polls
    the health endpoint until it's accepting connections, raising
    RuntimeError if the deadline is exceeded.

    Returns:
        The server base URL (defaults to http://localhost:8283).
    """

    def _run_server() -> None:
        # Load env vars before importing the app so server config is picked up.
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up (or timeout). The per-request timeout
        # ensures a single hung connection cannot block past the deadline —
        # without it, requests.get() may wait indefinitely.
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        while time.time() < deadline:
            try:
                resp = requests.get(url + "/v1/health", timeout=5)
                # Any non-5xx response means the server is serving requests.
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                # Server not accepting connections yet; keep polling.
                pass
            time.sleep(0.1)
        else:
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    return url
@pytest.fixture(scope="function")
def client(server_url: str):
    """
    Yield a fresh asynchronous Letta REST client bound to the test server.
    """
    yield AsyncLetta(base_url=server_url)
# NOTE(review): this async test relies on pytest-asyncio auto mode (or a
# marker applied elsewhere) to be collected and awaited — confirm config.
async def test_pinecone_tool(client: AsyncLetta) -> None:
    """
    Test the Pinecone tool integration with the Letta client.

    Imports a knowledge-base agent from an agent file, wires up Pinecone
    credentials via tool exec env vars, streams a user question through the
    agent while accumulating reasoning / tool-call / tool-return content,
    then verifies that the streamed messages match what was persisted to
    the database. Cleans up by deleting the agent at the end.
    """
    # Import the agent from the agent-file fixture; path is relative to the
    # test's working directory — assumes tests run from this test dir.
    with open("../../scripts/test-afs/knowledge-base.af", "rb") as f:
        response = await client.agents.import_file(file=f)
    agent_id = response.agent_ids[0]
    # Inject Pinecone credentials so the agent's tools can reach the index.
    agent = await client.agents.modify(
        agent_id=agent_id,
        tool_exec_environment_variables={
            "PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"),
            "PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"),
            "PINECONE_NAMESPACE": os.getenv("PINECONE_NAMESPACE"),
        },
    )
    # Remember the latest pre-existing message so the post-stream DB query
    # can fetch only messages created by this test run.
    last_message = await client.agents.messages.list(
        agent_id=agent.id,
        limit=1,
    )
    # Streaming accumulators: one logical message per message-type
    # transition in the chunk stream.
    curr_message_type = None
    messages = []
    reasoning_content = []
    last_reasoning_update_ms = 0
    tool_call_content = ""
    tool_return_content = ""
    summary = None
    pinecone_results = None
    queries = []
    try:
        response = client.agents.messages.create_stream(
            agent_id=agent.id,
            messages=[
                MessageCreate(
                    role="user",
                    content=TEST_USER_MESSAGE,
                ),
            ],
            stream_tokens=True,
            request_options=RequestOptions(
                timeout_in_seconds=1000,
            ),
        )
        async for chunk in response:
            # A change in message_type starts a new logical message; reset
            # the per-message accumulators for the new type.
            if chunk.message_type != curr_message_type:
                messages.append(chunk)
                curr_message_type = chunk.message_type
                if curr_message_type == "reasoning_message":
                    reasoning_content = []
                if curr_message_type == "tool_call_message":
                    tool_call_content = ""
            if chunk.message_type == "reasoning_message":
                # Throttle reasoning rebuilds to at most one per
                # REASONING_THROTTLE_MS by sleeping when chunks arrive fast.
                # NOTE(review): last_reasoning_update_ms is set to the
                # pre-sleep timestamp, so consecutive fast chunks may each
                # sleep — presumably intentional pacing; confirm.
                now_ms = time.time_ns() // 1_000_000
                if now_ms - last_reasoning_update_ms < REASONING_THROTTLE_MS:
                    await asyncio.sleep(REASONING_THROTTLE_MS / 1000)
                last_reasoning_update_ms = now_ms
                if len(reasoning_content) == 0:
                    reasoning_content = [chunk.reasoning]
                else:
                    reasoning_content[-1] += chunk.reasoning
                # Rebuild the in-flight ReasoningMessage with the full
                # accumulated reasoning text so far.
                message_dict = messages[-1].model_dump()
                message_dict["reasoning"] = "".join(reasoning_content).strip()
                messages[-1] = ReasoningMessage(**message_dict)
            if chunk.message_type == "tool_return_message":
                tool_return_content += chunk.tool_return
                if chunk.status == "success":
                    try:
                        # Capture the outputs of the two tools under test;
                        # other tool returns are accumulated but unused.
                        if chunk.name == "summarize_pinecone_results":
                            json_response = json.loads(chunk.tool_return)
                            summary = json_response.get("summary", None)
                            pinecone_results = json_response.get("pinecone_results", None)
                            tool_return_content = ""
                        elif chunk.name == "craft_queries":
                            queries.append(chunk.tool_return)
                            tool_return_content = ""
                    except Exception as e:
                        # Malformed tool return: log and keep streaming.
                        print(f"Error parsing JSON response: {str(e)}. {chunk.tool_return}\n")
                        tool_return_content = ""
            if chunk.message_type == "tool_call_message":
                # Tool-call arguments stream incrementally; rebuild the
                # in-flight ToolCallMessage with the accumulated arguments.
                if chunk.tool_call.arguments is not None:
                    tool_call_content += chunk.tool_call.arguments
                    message_dict = messages[-1].model_dump()
                    message_dict["tool_call"]["arguments"] = tool_call_content
                    messages[-1] = ToolCallMessage(**message_dict)
    except Exception as e:
        # Surface partial tool-call content to aid debugging, then fail.
        print(f"Failed to fetch knowledge base response: {str(e)}\n")
        print(tool_call_content)
        raise e
    # Sanity checks on what the stream produced.
    assert len(messages) > 0, "No messages received from the agent."
    assert len(reasoning_content) > 0, "No reasoning content received from the agent."
    assert summary is not None, "No summary received from the agent."
    assert pinecone_results is not None, "No Pinecone results received from the agent."
    assert len(queries) > 0, "No queries received from the agent."
    # The stream must terminate with a stop reason followed by usage stats.
    assert messages[-2].message_type == "stop_reason", "Penultimate message in stream must be stop reason."
    assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats."
    # Compare the streamed messages against what was persisted to the DB,
    # excluding stream-only bookkeeping messages and the user's own message.
    response_messages_from_stream = [m for m in messages if m.message_type not in ["stop_reason", "usage_statistics"]]
    response_message_types_from_stream = [m.message_type for m in response_messages_from_stream]
    messages_from_db = await client.agents.messages.list(
        agent_id=agent.id,
        after=last_message[0].id,
        limit=100,
    )
    response_messages_from_db = [m for m in messages_from_db if m.message_type != "user_message"]
    response_message_types_from_db = [m.message_type for m in response_messages_from_db]
    assert len(response_messages_from_stream) == len(response_messages_from_db)
    assert response_message_types_from_stream == response_message_types_from_db
    # Pairwise check: each streamed message must match its DB counterpart
    # in identity and, per type, in content.
    for idx in range(len(response_messages_from_stream)):
        stream_message = response_messages_from_stream[idx]
        db_message = response_messages_from_db[idx]
        assert stream_message.message_type == db_message.message_type
        assert stream_message.id == db_message.id
        assert stream_message.otid == db_message.otid
        if stream_message.message_type == "reasoning_message":
            assert stream_message.reasoning == db_message.reasoning
        if stream_message.message_type == "tool_call_message":
            assert stream_message.tool_call.tool_call_id == db_message.tool_call.tool_call_id
            assert stream_message.tool_call.name == db_message.tool_call.name
            # Spot-check that each tool's arguments contain its expected keys.
            if stream_message.tool_call.name == "craft_queries":
                assert "queries" in stream_message.tool_call.arguments
                assert "queries" in db_message.tool_call.arguments
            if stream_message.tool_call.name == "search_and_store_pinecone_records":
                assert "query_text" in stream_message.tool_call.arguments
                assert "query_text" in db_message.tool_call.arguments
            if stream_message.tool_call.name == "summarize_pinecone_results":
                assert "summary" in stream_message.tool_call.arguments
                assert "summary" in db_message.tool_call.arguments
            # inner_thoughts must have been stripped from stored arguments.
            assert "inner_thoughts" not in stream_message.tool_call.arguments
            assert "inner_thoughts" not in db_message.tool_call.arguments
        if stream_message.message_type == "tool_return_message":
            assert stream_message.tool_return == db_message.tool_return
    # Clean up the imported agent.
    await client.agents.delete(agent_id=agent.id)