216 lines
8.5 KiB
Python
216 lines
8.5 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import threading
|
|
import time
|
|
|
|
import pytest
|
|
import requests
|
|
from dotenv import load_dotenv
|
|
from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage
|
|
from letta_client.core import RequestOptions
|
|
|
|
from tests.helpers.utils import upload_test_agentfile_from_disk
|
|
|
|
# Threshold in milliseconds: when a reasoning chunk arrives sooner than this
# after the previously recorded one, the stream loop sleeps for this long.
REASONING_THROTTLE_MS = 100
|
|
# Content of the single user message sent to the agent in test_pinecone_tool.
TEST_USER_MESSAGE = "What products or services does 11x AI sell?"
|
|
|
|
|
|
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.

    If LETTA_SERVER_URL is not set (or is empty), starts the server in a
    background daemon thread and polls its health endpoint until it's
    accepting connections.

    Returns:
        The value of LETTA_SERVER_URL, or "http://localhost:8283" when the
        server is started locally.

    Raises:
        RuntimeError: If the locally started server cannot be reached within
            the polling timeout.
    """

    def _run_server() -> None:
        load_dotenv()
        # Imported here so the server package is only loaded when this
        # fixture actually has to host the server itself.
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    # Read the variable once so the URL and the "start a local server"
    # decision can never disagree (e.g. when it is set to an empty string).
    configured_url = os.getenv("LETTA_SERVER_URL")
    url: str = configured_url or "http://localhost:8283"

    if not configured_url:
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up (or timeout)
        timeout_seconds = 30
        # Bound each probe so a hung connection cannot outlive the deadline.
        request_timeout_seconds = 2
        # Monotonic clock: unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            try:
                resp = requests.get(url + "/v1/health", timeout=request_timeout_seconds)
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                # Not accepting connections yet (or the probe timed out); retry.
                pass
            time.sleep(0.1)
        else:
            # Loop ended without `break`: the deadline passed with no healthy reply.
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    return url
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client(server_url: str):
    """
    Yields a fresh asynchronous Letta REST client bound to ``server_url``.
    """
    yield AsyncLetta(base_url=server_url)
|
|
|
|
|
|
async def test_pinecone_tool(client: AsyncLetta, server_url: str) -> None:
    """
    Test the Pinecone tool integration with the Letta client.

    Uploads the "knowledge-base.af" agent file, sets the Pinecone variables
    from the process environment on the agent, streams a reply to
    TEST_USER_MESSAGE, and then checks that the messages assembled from the
    stream match the ones returned by ``client.agents.messages.list``.

    The uploaded agent is deleted when the test ends, whether it passed or
    failed.
    """
    response = upload_test_agentfile_from_disk(server_url, "knowledge-base.af")

    agent_id = response.agent_ids[0]

    try:
        agent = await client.agents.modify(
            agent_id=agent_id,
            tool_exec_environment_variables={
                "PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"),
                "PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"),
                "PINECONE_NAMESPACE": os.getenv("PINECONE_NAMESPACE"),
            },
        )
        # Remember the newest pre-existing message so only messages created by
        # this test are fetched for the comparison below.
        last_message = await client.agents.messages.list(
            agent_id=agent.id,
            limit=1,
        )

        curr_message_type = None
        messages = []
        reasoning_content = []
        last_reasoning_update_ms = 0
        tool_call_content = ""
        summary = None
        pinecone_results = None
        queries = []

        try:
            stream = client.agents.messages.create_stream(
                agent_id=agent.id,
                messages=[
                    MessageCreate(
                        role="user",
                        content=TEST_USER_MESSAGE,
                    ),
                ],
                stream_tokens=True,
                request_options=RequestOptions(
                    timeout_in_seconds=1000,
                ),
            )

            async for chunk in stream:
                # A change of message type starts a new logical message; the
                # chunks that follow with the same type are merged into it.
                if chunk.message_type != curr_message_type:
                    messages.append(chunk)
                    curr_message_type = chunk.message_type
                    if curr_message_type == "reasoning_message":
                        reasoning_content = []
                    if curr_message_type == "tool_call_message":
                        tool_call_content = ""

                if chunk.message_type == "reasoning_message":
                    now_ms = time.time_ns() // 1_000_000
                    if now_ms - last_reasoning_update_ms < REASONING_THROTTLE_MS:
                        await asyncio.sleep(REASONING_THROTTLE_MS / 1000)

                    # Note: this records the timestamp taken before the sleep.
                    last_reasoning_update_ms = now_ms
                    if not reasoning_content:
                        reasoning_content = [chunk.reasoning]
                    else:
                        reasoning_content[-1] += chunk.reasoning

                    # Rebuild the stored message with the reasoning text
                    # accumulated so far.
                    message_dict = messages[-1].model_dump()
                    message_dict["reasoning"] = "".join(reasoning_content).strip()
                    messages[-1] = ReasoningMessage(**message_dict)

                if chunk.message_type == "tool_return_message" and chunk.status == "success":
                    try:
                        if chunk.name == "summarize_pinecone_results":
                            json_response = json.loads(chunk.tool_return)
                            summary = json_response.get("summary", None)
                            pinecone_results = json_response.get("pinecone_results", None)
                        elif chunk.name == "craft_queries":
                            queries.append(chunk.tool_return)
                    except Exception as e:
                        # Best effort: a malformed payload is reported here and
                        # surfaces later through the summary/results asserts.
                        print(f"Error parsing JSON response: {str(e)}. {chunk.tool_return}\n")

                if chunk.message_type == "tool_call_message":
                    if chunk.tool_call.arguments is not None:
                        tool_call_content += chunk.tool_call.arguments
                        # Rebuild the stored message with the argument text
                        # accumulated so far.
                        message_dict = messages[-1].model_dump()
                        message_dict["tool_call"]["arguments"] = tool_call_content
                        messages[-1] = ToolCallMessage(**message_dict)

        except Exception as e:
            print(f"Failed to fetch knowledge base response: {str(e)}\n")
            print(tool_call_content)
            raise

        assert messages, "No messages received from the agent."
        assert reasoning_content, "No reasoning content received from the agent."
        assert summary is not None, "No summary received from the agent."
        assert pinecone_results is not None, "No Pinecone results received from the agent."
        assert queries, "No queries received from the agent."

        assert messages[-2].message_type == "stop_reason", "Penultimate message in stream must be stop reason."
        assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats."
        response_messages_from_stream = [m for m in messages if m.message_type not in ["stop_reason", "usage_statistics"]]
        response_message_types_from_stream = [m.message_type for m in response_messages_from_stream]

        messages_from_db = await client.agents.messages.list(
            agent_id=agent.id,
            after=last_message[0].id,
            limit=100,
        )
        response_messages_from_db = [m for m in messages_from_db if m.message_type != "user_message"]
        response_message_types_from_db = [m.message_type for m in response_messages_from_db]

        assert len(response_messages_from_stream) == len(response_messages_from_db)
        assert response_message_types_from_stream == response_message_types_from_db
        # Lengths were asserted equal above, so zip() pairs every message.
        for stream_message, db_message in zip(response_messages_from_stream, response_messages_from_db):
            assert stream_message.message_type == db_message.message_type
            assert stream_message.id == db_message.id
            assert stream_message.otid == db_message.otid

            if stream_message.message_type == "reasoning_message":
                assert stream_message.reasoning == db_message.reasoning

            if stream_message.message_type == "tool_call_message":
                assert stream_message.tool_call.tool_call_id == db_message.tool_call.tool_call_id
                assert stream_message.tool_call.name == db_message.tool_call.name

                if stream_message.tool_call.name == "craft_queries":
                    assert "queries" in stream_message.tool_call.arguments
                    assert "queries" in db_message.tool_call.arguments
                if stream_message.tool_call.name == "search_and_store_pinecone_records":
                    assert "query_text" in stream_message.tool_call.arguments
                    assert "query_text" in db_message.tool_call.arguments
                if stream_message.tool_call.name == "summarize_pinecone_results":
                    assert "summary" in stream_message.tool_call.arguments
                    assert "summary" in db_message.tool_call.arguments

                assert "inner_thoughts" not in stream_message.tool_call.arguments
                assert "inner_thoughts" not in db_message.tool_call.arguments

            if stream_message.message_type == "tool_return_message":
                assert stream_message.tool_return == db_message.tool_return
    finally:
        # Always remove the uploaded agent so a failing run does not leave it
        # behind on the server.
        await client.agents.delete(agent_id=agent_id)
|