From 1f3dd068dba1a2ca9b880d81ac5acd81033bcda4 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 4 Jun 2025 17:06:32 -0700 Subject: [PATCH] feat: add pinecone integration test (#2636) --- tests/integration_test_pinecone_tool.py | 211 ++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 tests/integration_test_pinecone_tool.py diff --git a/tests/integration_test_pinecone_tool.py b/tests/integration_test_pinecone_tool.py new file mode 100644 index 00000000..9ce7589d --- /dev/null +++ b/tests/integration_test_pinecone_tool.py @@ -0,0 +1,211 @@ +import asyncio +import json +import os +import threading +import time + +import pytest +import requests +from dotenv import load_dotenv +from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage +from letta_client.core import RequestOptions + +REASONING_THROTTLE_MS = 100 +TEST_USER_MESSAGE = "What products or services does 11x AI sell?" + + +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it’s accepting connections. + """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + return url + + +@pytest.fixture(scope="function") +def client(server_url: str): + """ + Creates and returns an asynchronous Letta REST client for testing. + """ + async_client_instance = AsyncLetta(base_url=server_url) + yield async_client_instance + + +async def test_pinecone_tool(client: AsyncLetta) -> None: + """ + Test the Pinecone tool integration with the Letta client. + """ + with open("../../scripts/test-afs/knowledge-base.af", "rb") as f: + agent = await client.agents.import_file(file=f) + + agent = await client.agents.modify( + agent_id=agent.id, + tool_exec_environment_variables={ + "PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"), + "PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"), + "PINECONE_NAMESPACE": os.getenv("PINECONE_NAMESPACE"), + }, + ) + last_message = await client.agents.messages.list( + agent_id=agent.id, + limit=1, + ) + + curr_message_type = None + messages = [] + reasoning_content = [] + last_reasoning_update_ms = 0 + tool_call_content = "" + tool_return_content = "" + summary = None + pinecone_results = None + queries = [] + + try: + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content=TEST_USER_MESSAGE, + ), + ], + stream_tokens=True, + request_options=RequestOptions( + timeout_in_seconds=1000, + ), + ) + + async for chunk in response: + if chunk.message_type != curr_message_type: + messages.append(chunk) + curr_message_type = chunk.message_type + if curr_message_type == "reasoning_message": + reasoning_content = [] + if curr_message_type == "tool_call_message": + tool_call_content = "" + + if chunk.message_type == "reasoning_message": + now_ms = time.time_ns() // 1_000_000 + if now_ms - last_reasoning_update_ms < REASONING_THROTTLE_MS: + await asyncio.sleep(REASONING_THROTTLE_MS / 1000) + + last_reasoning_update_ms = now_ms + if len(reasoning_content) == 0: + reasoning_content = [chunk.reasoning] + else: + reasoning_content[-1] += chunk.reasoning + + message_dict = messages[-1].model_dump() + message_dict["reasoning"] = "".join(reasoning_content).strip() + messages[-1] = ReasoningMessage(**message_dict) + + if chunk.message_type == "tool_return_message": + tool_return_content += chunk.tool_return + + if chunk.status == "success": + try: + if chunk.name == "summarize_pinecone_results": + json_response = json.loads(chunk.tool_return) + summary = json_response.get("summary", None) + pinecone_results = json_response.get("pinecone_results", None) + tool_return_content = "" + elif chunk.name == "craft_queries": + queries.append(chunk.tool_return) + tool_return_content = "" + except Exception as e: + print(f"Error parsing JSON response: {str(e)}. {chunk.tool_return}\n") + tool_return_content = "" + + if chunk.message_type == "tool_call_message": + if chunk.tool_call.arguments is not None: + tool_call_content += chunk.tool_call.arguments + message_dict = messages[-1].model_dump() + message_dict["tool_call"]["arguments"] = tool_call_content + messages[-1] = ToolCallMessage(**message_dict) + + except Exception as e: + print(f"Failed to fetch knowledge base response: {str(e)}\n") + print(tool_call_content) + raise e + + assert len(messages) > 0, "No messages received from the agent." + assert len(reasoning_content) > 0, "No reasoning content received from the agent." + assert summary is not None, "No summary received from the agent." + assert pinecone_results is not None, "No Pinecone results received from the agent." + assert len(queries) > 0, "No queries received from the agent." + + assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats." + response_messages_from_stream = [m for m in messages if m.message_type != "usage_statistics"] + response_message_types_from_stream = [m.message_type for m in response_messages_from_stream] + + messages_from_db = await client.agents.messages.list( + agent_id=agent.id, + after=last_message[0].id, + limit=100, + ) + response_messages_from_db = [m for m in messages_from_db if m.message_type != "user_message"] + response_message_types_from_db = [m.message_type for m in response_messages_from_db] + + assert len(response_messages_from_stream) == len(response_messages_from_db) + assert response_message_types_from_stream == response_message_types_from_db + for idx in range(len(response_messages_from_stream)): + stream_message = response_messages_from_stream[idx] + db_message = response_messages_from_db[idx] + assert stream_message.message_type == db_message.message_type + assert stream_message.id == db_message.id + assert stream_message.otid == db_message.otid + + if stream_message.message_type == "reasoning_message": + assert stream_message.reasoning == db_message.reasoning + + if stream_message.message_type == "tool_call_message": + assert stream_message.tool_call.tool_call_id == db_message.tool_call.tool_call_id + assert stream_message.tool_call.name == db_message.tool_call.name + + if stream_message.tool_call.name == "craft_queries": + assert "queries" in stream_message.tool_call.arguments + assert "queries" in db_message.tool_call.arguments + if stream_message.tool_call.name == "search_and_store_pinecone_records": + assert "query_text" in stream_message.tool_call.arguments + assert "query_text" in db_message.tool_call.arguments + if stream_message.tool_call.name == "summarize_pinecone_results": + assert "summary" in stream_message.tool_call.arguments + assert "summary" in db_message.tool_call.arguments + + assert "inner_thoughts" not in stream_message.tool_call.arguments + assert "inner_thoughts" not in db_message.tool_call.arguments + + if stream_message.message_type == "tool_return_message": + assert stream_message.tool_return == db_message.tool_return + + await client.agents.delete(agent_id=agent.id)