Co-authored-by: Shubham Naik <shub@memgpt.ai> Co-authored-by: Shubham Naik <shub@letta.com>
215 lines
8.4 KiB
Python
215 lines
8.4 KiB
Python
import asyncio
|
||
import json
|
||
import os
|
||
import threading
|
||
import time
|
||
|
||
import pytest
|
||
import requests
|
||
from dotenv import load_dotenv
|
||
from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage
|
||
from letta_client.core import RequestOptions
|
||
|
||
REASONING_THROTTLE_MS = 100
|
||
TEST_USER_MESSAGE = "What products or services does 11x AI sell?"
|
||
|
||
|
||
@pytest.fixture(scope="module")
|
||
def server_url() -> str:
|
||
"""
|
||
Provides the URL for the Letta server.
|
||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||
and polls until it’s accepting connections.
|
||
"""
|
||
|
||
def _run_server() -> None:
|
||
load_dotenv()
|
||
from letta.server.rest_api.app import start_server
|
||
|
||
start_server(debug=True)
|
||
|
||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||
|
||
if not os.getenv("LETTA_SERVER_URL"):
|
||
thread = threading.Thread(target=_run_server, daemon=True)
|
||
thread.start()
|
||
|
||
# Poll until the server is up (or timeout)
|
||
timeout_seconds = 30
|
||
deadline = time.time() + timeout_seconds
|
||
while time.time() < deadline:
|
||
try:
|
||
resp = requests.get(url + "/v1/health")
|
||
if resp.status_code < 500:
|
||
break
|
||
except requests.exceptions.RequestException:
|
||
pass
|
||
time.sleep(0.1)
|
||
else:
|
||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||
|
||
return url
|
||
|
||
|
||
@pytest.fixture(scope="function")
|
||
def client(server_url: str):
|
||
"""
|
||
Creates and returns an asynchronous Letta REST client for testing.
|
||
"""
|
||
async_client_instance = AsyncLetta(base_url=server_url)
|
||
yield async_client_instance
|
||
|
||
|
||
async def test_pinecone_tool(client: AsyncLetta) -> None:
|
||
"""
|
||
Test the Pinecone tool integration with the Letta client.
|
||
"""
|
||
with open("../../scripts/test-afs/knowledge-base.af", "rb") as f:
|
||
response = await client.agents.import_file(file=f)
|
||
|
||
agent_id = response.agent_ids[0]
|
||
|
||
agent = await client.agents.modify(
|
||
agent_id=agent_id,
|
||
tool_exec_environment_variables={
|
||
"PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"),
|
||
"PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"),
|
||
"PINECONE_NAMESPACE": os.getenv("PINECONE_NAMESPACE"),
|
||
},
|
||
)
|
||
last_message = await client.agents.messages.list(
|
||
agent_id=agent.id,
|
||
limit=1,
|
||
)
|
||
|
||
curr_message_type = None
|
||
messages = []
|
||
reasoning_content = []
|
||
last_reasoning_update_ms = 0
|
||
tool_call_content = ""
|
||
tool_return_content = ""
|
||
summary = None
|
||
pinecone_results = None
|
||
queries = []
|
||
|
||
try:
|
||
response = client.agents.messages.create_stream(
|
||
agent_id=agent.id,
|
||
messages=[
|
||
MessageCreate(
|
||
role="user",
|
||
content=TEST_USER_MESSAGE,
|
||
),
|
||
],
|
||
stream_tokens=True,
|
||
request_options=RequestOptions(
|
||
timeout_in_seconds=1000,
|
||
),
|
||
)
|
||
|
||
async for chunk in response:
|
||
if chunk.message_type != curr_message_type:
|
||
messages.append(chunk)
|
||
curr_message_type = chunk.message_type
|
||
if curr_message_type == "reasoning_message":
|
||
reasoning_content = []
|
||
if curr_message_type == "tool_call_message":
|
||
tool_call_content = ""
|
||
|
||
if chunk.message_type == "reasoning_message":
|
||
now_ms = time.time_ns() // 1_000_000
|
||
if now_ms - last_reasoning_update_ms < REASONING_THROTTLE_MS:
|
||
await asyncio.sleep(REASONING_THROTTLE_MS / 1000)
|
||
|
||
last_reasoning_update_ms = now_ms
|
||
if len(reasoning_content) == 0:
|
||
reasoning_content = [chunk.reasoning]
|
||
else:
|
||
reasoning_content[-1] += chunk.reasoning
|
||
|
||
message_dict = messages[-1].model_dump()
|
||
message_dict["reasoning"] = "".join(reasoning_content).strip()
|
||
messages[-1] = ReasoningMessage(**message_dict)
|
||
|
||
if chunk.message_type == "tool_return_message":
|
||
tool_return_content += chunk.tool_return
|
||
|
||
if chunk.status == "success":
|
||
try:
|
||
if chunk.name == "summarize_pinecone_results":
|
||
json_response = json.loads(chunk.tool_return)
|
||
summary = json_response.get("summary", None)
|
||
pinecone_results = json_response.get("pinecone_results", None)
|
||
tool_return_content = ""
|
||
elif chunk.name == "craft_queries":
|
||
queries.append(chunk.tool_return)
|
||
tool_return_content = ""
|
||
except Exception as e:
|
||
print(f"Error parsing JSON response: {str(e)}. {chunk.tool_return}\n")
|
||
tool_return_content = ""
|
||
|
||
if chunk.message_type == "tool_call_message":
|
||
if chunk.tool_call.arguments is not None:
|
||
tool_call_content += chunk.tool_call.arguments
|
||
message_dict = messages[-1].model_dump()
|
||
message_dict["tool_call"]["arguments"] = tool_call_content
|
||
messages[-1] = ToolCallMessage(**message_dict)
|
||
|
||
except Exception as e:
|
||
print(f"Failed to fetch knowledge base response: {str(e)}\n")
|
||
print(tool_call_content)
|
||
raise e
|
||
|
||
assert len(messages) > 0, "No messages received from the agent."
|
||
assert len(reasoning_content) > 0, "No reasoning content received from the agent."
|
||
assert summary is not None, "No summary received from the agent."
|
||
assert pinecone_results is not None, "No Pinecone results received from the agent."
|
||
assert len(queries) > 0, "No queries received from the agent."
|
||
|
||
assert messages[-2].message_type == "stop_reason", "Penultimate message in stream must be stop reason."
|
||
assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats."
|
||
response_messages_from_stream = [m for m in messages if m.message_type not in ["stop_reason", "usage_statistics"]]
|
||
response_message_types_from_stream = [m.message_type for m in response_messages_from_stream]
|
||
|
||
messages_from_db = await client.agents.messages.list(
|
||
agent_id=agent.id,
|
||
after=last_message[0].id,
|
||
limit=100,
|
||
)
|
||
response_messages_from_db = [m for m in messages_from_db if m.message_type != "user_message"]
|
||
response_message_types_from_db = [m.message_type for m in response_messages_from_db]
|
||
|
||
assert len(response_messages_from_stream) == len(response_messages_from_db)
|
||
assert response_message_types_from_stream == response_message_types_from_db
|
||
for idx in range(len(response_messages_from_stream)):
|
||
stream_message = response_messages_from_stream[idx]
|
||
db_message = response_messages_from_db[idx]
|
||
assert stream_message.message_type == db_message.message_type
|
||
assert stream_message.id == db_message.id
|
||
assert stream_message.otid == db_message.otid
|
||
|
||
if stream_message.message_type == "reasoning_message":
|
||
assert stream_message.reasoning == db_message.reasoning
|
||
|
||
if stream_message.message_type == "tool_call_message":
|
||
assert stream_message.tool_call.tool_call_id == db_message.tool_call.tool_call_id
|
||
assert stream_message.tool_call.name == db_message.tool_call.name
|
||
|
||
if stream_message.tool_call.name == "craft_queries":
|
||
assert "queries" in stream_message.tool_call.arguments
|
||
assert "queries" in db_message.tool_call.arguments
|
||
if stream_message.tool_call.name == "search_and_store_pinecone_records":
|
||
assert "query_text" in stream_message.tool_call.arguments
|
||
assert "query_text" in db_message.tool_call.arguments
|
||
if stream_message.tool_call.name == "summarize_pinecone_results":
|
||
assert "summary" in stream_message.tool_call.arguments
|
||
assert "summary" in db_message.tool_call.arguments
|
||
|
||
assert "inner_thoughts" not in stream_message.tool_call.arguments
|
||
assert "inner_thoughts" not in db_message.tool_call.arguments
|
||
|
||
if stream_message.message_type == "tool_return_message":
|
||
assert stream_message.tool_return == db_message.tool_return
|
||
|
||
await client.agents.delete(agent_id=agent.id)
|