feat: add pinecone integration test (#2636)
This commit is contained in:
211
tests/integration_test_pinecone_tool.py
Normal file
211
tests/integration_test_pinecone_tool.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage
|
||||
from letta_client.core import RequestOptions
|
||||
|
||||
REASONING_THROTTLE_MS = 100
|
||||
TEST_USER_MESSAGE = "What products or services does 11x AI sell?"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it’s accepting connections.
|
||||
"""
|
||||
|
||||
def _run_server() -> None:
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(server_url: str):
|
||||
"""
|
||||
Creates and returns an asynchronous Letta REST client for testing.
|
||||
"""
|
||||
async_client_instance = AsyncLetta(base_url=server_url)
|
||||
yield async_client_instance
|
||||
|
||||
|
||||
async def test_pinecone_tool(client: AsyncLetta) -> None:
|
||||
"""
|
||||
Test the Pinecone tool integration with the Letta client.
|
||||
"""
|
||||
with open("../../scripts/test-afs/knowledge-base.af", "rb") as f:
|
||||
agent = await client.agents.import_file(file=f)
|
||||
|
||||
agent = await client.agents.modify(
|
||||
agent_id=agent.id,
|
||||
tool_exec_environment_variables={
|
||||
"PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"),
|
||||
"PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"),
|
||||
"PINECONE_NAMESPACE": os.getenv("PINECONE_NAMESPACE"),
|
||||
},
|
||||
)
|
||||
last_message = await client.agents.messages.list(
|
||||
agent_id=agent.id,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
curr_message_type = None
|
||||
messages = []
|
||||
reasoning_content = []
|
||||
last_reasoning_update_ms = 0
|
||||
tool_call_content = ""
|
||||
tool_return_content = ""
|
||||
summary = None
|
||||
pinecone_results = None
|
||||
queries = []
|
||||
|
||||
try:
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=TEST_USER_MESSAGE,
|
||||
),
|
||||
],
|
||||
stream_tokens=True,
|
||||
request_options=RequestOptions(
|
||||
timeout_in_seconds=1000,
|
||||
),
|
||||
)
|
||||
|
||||
async for chunk in response:
|
||||
if chunk.message_type != curr_message_type:
|
||||
messages.append(chunk)
|
||||
curr_message_type = chunk.message_type
|
||||
if curr_message_type == "reasoning_message":
|
||||
reasoning_content = []
|
||||
if curr_message_type == "tool_call_message":
|
||||
tool_call_content = ""
|
||||
|
||||
if chunk.message_type == "reasoning_message":
|
||||
now_ms = time.time_ns() // 1_000_000
|
||||
if now_ms - last_reasoning_update_ms < REASONING_THROTTLE_MS:
|
||||
await asyncio.sleep(REASONING_THROTTLE_MS / 1000)
|
||||
|
||||
last_reasoning_update_ms = now_ms
|
||||
if len(reasoning_content) == 0:
|
||||
reasoning_content = [chunk.reasoning]
|
||||
else:
|
||||
reasoning_content[-1] += chunk.reasoning
|
||||
|
||||
message_dict = messages[-1].model_dump()
|
||||
message_dict["reasoning"] = "".join(reasoning_content).strip()
|
||||
messages[-1] = ReasoningMessage(**message_dict)
|
||||
|
||||
if chunk.message_type == "tool_return_message":
|
||||
tool_return_content += chunk.tool_return
|
||||
|
||||
if chunk.status == "success":
|
||||
try:
|
||||
if chunk.name == "summarize_pinecone_results":
|
||||
json_response = json.loads(chunk.tool_return)
|
||||
summary = json_response.get("summary", None)
|
||||
pinecone_results = json_response.get("pinecone_results", None)
|
||||
tool_return_content = ""
|
||||
elif chunk.name == "craft_queries":
|
||||
queries.append(chunk.tool_return)
|
||||
tool_return_content = ""
|
||||
except Exception as e:
|
||||
print(f"Error parsing JSON response: {str(e)}. {chunk.tool_return}\n")
|
||||
tool_return_content = ""
|
||||
|
||||
if chunk.message_type == "tool_call_message":
|
||||
if chunk.tool_call.arguments is not None:
|
||||
tool_call_content += chunk.tool_call.arguments
|
||||
message_dict = messages[-1].model_dump()
|
||||
message_dict["tool_call"]["arguments"] = tool_call_content
|
||||
messages[-1] = ToolCallMessage(**message_dict)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch knowledge base response: {str(e)}\n")
|
||||
print(tool_call_content)
|
||||
raise e
|
||||
|
||||
assert len(messages) > 0, "No messages received from the agent."
|
||||
assert len(reasoning_content) > 0, "No reasoning content received from the agent."
|
||||
assert summary is not None, "No summary received from the agent."
|
||||
assert pinecone_results is not None, "No Pinecone results received from the agent."
|
||||
assert len(queries) > 0, "No queries received from the agent."
|
||||
|
||||
assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats."
|
||||
response_messages_from_stream = [m for m in messages if m.message_type != "usage_statistics"]
|
||||
response_message_types_from_stream = [m.message_type for m in response_messages_from_stream]
|
||||
|
||||
messages_from_db = await client.agents.messages.list(
|
||||
agent_id=agent.id,
|
||||
after=last_message[0].id,
|
||||
limit=100,
|
||||
)
|
||||
response_messages_from_db = [m for m in messages_from_db if m.message_type != "user_message"]
|
||||
response_message_types_from_db = [m.message_type for m in response_messages_from_db]
|
||||
|
||||
assert len(response_messages_from_stream) == len(response_messages_from_db)
|
||||
assert response_message_types_from_stream == response_message_types_from_db
|
||||
for idx in range(len(response_messages_from_stream)):
|
||||
stream_message = response_messages_from_stream[idx]
|
||||
db_message = response_messages_from_db[idx]
|
||||
assert stream_message.message_type == db_message.message_type
|
||||
assert stream_message.id == db_message.id
|
||||
assert stream_message.otid == db_message.otid
|
||||
|
||||
if stream_message.message_type == "reasoning_message":
|
||||
assert stream_message.reasoning == db_message.reasoning
|
||||
|
||||
if stream_message.message_type == "tool_call_message":
|
||||
assert stream_message.tool_call.tool_call_id == db_message.tool_call.tool_call_id
|
||||
assert stream_message.tool_call.name == db_message.tool_call.name
|
||||
|
||||
if stream_message.tool_call.name == "craft_queries":
|
||||
assert "queries" in stream_message.tool_call.arguments
|
||||
assert "queries" in db_message.tool_call.arguments
|
||||
if stream_message.tool_call.name == "search_and_store_pinecone_records":
|
||||
assert "query_text" in stream_message.tool_call.arguments
|
||||
assert "query_text" in db_message.tool_call.arguments
|
||||
if stream_message.tool_call.name == "summarize_pinecone_results":
|
||||
assert "summary" in stream_message.tool_call.arguments
|
||||
assert "summary" in db_message.tool_call.arguments
|
||||
|
||||
assert "inner_thoughts" not in stream_message.tool_call.arguments
|
||||
assert "inner_thoughts" not in db_message.tool_call.arguments
|
||||
|
||||
if stream_message.message_type == "tool_return_message":
|
||||
assert stream_message.tool_return == db_message.tool_return
|
||||
|
||||
await client.agents.delete(agent_id=agent.id)
|
||||
Reference in New Issue
Block a user