Add printed out responses for easier debugging from tests

This commit is contained in:
Matt Zhou
2024-10-04 15:19:40 -07:00
parent 776d7dd6e8
commit 6bcec854d6
6 changed files with 75 additions and 15 deletions

View File

@@ -1,7 +1,11 @@
import json
import logging
import uuid
from typing import Callable, List, Optional, Union
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
from letta import LocalClient, RESTClient, create_client
from letta.agent import Agent
from letta.config import LettaConfig
@@ -26,7 +30,12 @@ from letta.schemas.letta_message import (
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message
from letta.schemas.openai.chat_completion_response import (
ChatCompletionResponse,
Choice,
FunctionCall,
Message,
)
from letta.utils import get_human_text, get_persona_text
from tests.helpers.utils import cleanup
@@ -68,7 +77,13 @@ def setup_agent(
return agent_state
def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False):
# ======================================================================================================================
# Section: Complex E2E Tests
# These functions describe individual testing scenarios.
# ======================================================================================================================
def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False) -> ChatCompletionResponse:
"""
Checks that the first response is valid:
@@ -110,8 +125,10 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts
# Assert that the message has an inner monologue
assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs)
return response
def check_response_contains_keyword(filename: str):
def check_response_contains_keyword(filename: str, keyword="banana") -> LettaResponse:
"""
Checks that the prompted response from the LLM contains a chosen keyword
@@ -121,7 +138,6 @@ def check_response_contains_keyword(filename: str):
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename)
keyword = "banana"
keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"'
response = client.user_message(agent_id=agent_state.id, message=keyword_message)
@@ -134,8 +150,10 @@ def check_response_contains_keyword(filename: str):
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_uses_external_tool(filename: str):
def check_agent_uses_external_tool(filename: str) -> LettaResponse:
"""
Checks that the LLM will use external tools if instructed
@@ -177,8 +195,10 @@ def check_agent_uses_external_tool(filename: str):
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_recall_chat_memory(filename: str):
def check_agent_recall_chat_memory(filename: str) -> LettaResponse:
"""
Checks that the LLM will recall the chat memory, specifically the human persona.
@@ -202,8 +222,10 @@ def check_agent_recall_chat_memory(filename: str):
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_archival_memory_retrieval(filename: str):
def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse:
"""
Checks that the LLM will execute an archival memory retrieval.
@@ -230,6 +252,8 @@ def check_agent_archival_memory_retrieval(filename: str):
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def run_embedding_endpoint(filename):
# load JSON file
@@ -255,7 +279,7 @@ def assert_sanity_checks(response: LettaResponse):
assert len(response.messages) > 0
def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str) -> None:
def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
# Find first instance of send_message
target_message = None
for message in messages:
@@ -280,6 +304,10 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
)
# Check that the keyword is in the message arguments
if not case_sensitive:
keyword = keyword.lower()
arguments["message"] = arguments["message"].lower()
if not keyword in arguments["message"]:
raise InvalidFunctionCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")