Add printed out responses for easier debugging from tests

This commit is contained in:
Matt Zhou
2024-10-04 15:19:40 -07:00
parent 776d7dd6e8
commit 6bcec854d6
6 changed files with 75 additions and 15 deletions

View File

@@ -53,7 +53,7 @@ class BaseBlock(LettaBase, validate_assignment=True):
super().__setattr__(name, value)
if name == "value":
# run validation
self.__class__.validate(self.dict(exclude_unset=True))
self.__class__.model_validate(self.model_dump(exclude_unset=True))
class Block(BaseBlock):

View File

@@ -1,3 +1,4 @@
import json
from typing import List, Union
from pydantic import BaseModel, Field
@@ -23,6 +24,16 @@ class LettaResponse(BaseModel):
messages: Union[List[Message], List[LettaMessage]] = Field(..., description="The messages returned by the agent.")
usage: LettaUsageStatistics = Field(..., description="The usage statistics of the agent.")
def __str__(self):
return json.dumps(
{
"messages": [message.model_dump() for message in self.messages],
# Assume `Message` and `LettaMessage` have a `dict()` method
"usage": self.usage.model_dump(), # Assume `LettaUsageStatistics` has a `dict()` method
},
indent=4,
)
# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus]

View File

@@ -456,7 +456,7 @@ class SyncServer(Server):
logger.debug("Calling step_yield()")
letta_agent.interface.step_yield()
return LettaUsageStatistics(**total_usage.dict(), step_count=step_count)
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
"""Process a CLI command"""

5
tests/conftest.py Normal file
View File

@@ -0,0 +1,5 @@
import logging
def pytest_configure(config):
logging.basicConfig(level=logging.DEBUG)

View File

@@ -1,7 +1,11 @@
import json
import logging
import uuid
from typing import Callable, List, Optional, Union
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
from letta import LocalClient, RESTClient, create_client
from letta.agent import Agent
from letta.config import LettaConfig
@@ -26,7 +30,12 @@ from letta.schemas.letta_message import (
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message
from letta.schemas.openai.chat_completion_response import (
ChatCompletionResponse,
Choice,
FunctionCall,
Message,
)
from letta.utils import get_human_text, get_persona_text
from tests.helpers.utils import cleanup
@@ -68,7 +77,13 @@ def setup_agent(
return agent_state
def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False):
# ======================================================================================================================
# Section: Complex E2E Tests
# These functions describe individual testing scenarios.
# ======================================================================================================================
def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False) -> ChatCompletionResponse:
"""
Checks that the first response is valid:
@@ -110,8 +125,10 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts
# Assert that the message has an inner monologue
assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs)
return response
def check_response_contains_keyword(filename: str):
def check_response_contains_keyword(filename: str, keyword="banana") -> LettaResponse:
"""
Checks that the prompted response from the LLM contains a chosen keyword
@@ -121,7 +138,6 @@ def check_response_contains_keyword(filename: str):
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename)
keyword = "banana"
keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"'
response = client.user_message(agent_id=agent_state.id, message=keyword_message)
@@ -134,8 +150,10 @@ def check_response_contains_keyword(filename: str):
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_uses_external_tool(filename: str):
def check_agent_uses_external_tool(filename: str) -> LettaResponse:
"""
Checks that the LLM will use external tools if instructed
@@ -177,8 +195,10 @@ def check_agent_uses_external_tool(filename: str):
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_recall_chat_memory(filename: str):
def check_agent_recall_chat_memory(filename: str) -> LettaResponse:
"""
Checks that the LLM will recall the chat memory, specifically the human persona.
@@ -202,8 +222,10 @@ def check_agent_recall_chat_memory(filename: str):
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_archival_memory_retrieval(filename: str):
def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse:
"""
Checks that the LLM will execute an archival memory retrieval.
@@ -230,6 +252,8 @@ def check_agent_archival_memory_retrieval(filename: str):
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def run_embedding_endpoint(filename):
# load JSON file
@@ -255,7 +279,7 @@ def assert_sanity_checks(response: LettaResponse):
assert len(response.messages) > 0
def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str) -> None:
def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
# Find first instance of send_message
target_message = None
for message in messages:
@@ -280,6 +304,10 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
)
# Check that the keyword is in the message arguments
if not case_sensitive:
keyword = keyword.lower()
arguments["message"] = arguments["message"].lower()
if not keyword in arguments["message"]:
raise InvalidFunctionCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")

View File

@@ -19,27 +19,43 @@ llm_config_dir = "configs/llm_model_configs"
# ======================================================================================================================
def test_openai_gpt_4_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_first_response_is_valid_for_llm_endpoint(filename)
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_openai_gpt_4_returns_keyword():
keyword = "banana"
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_response_contains_keyword(filename)
response = check_response_contains_keyword(filename, keyword=keyword)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_openai_gpt_4_uses_external_tool():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_agent_uses_external_tool(filename)
response = check_agent_uses_external_tool(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_openai_gpt_4_recall_chat_memory():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_agent_recall_chat_memory(filename)
response = check_agent_recall_chat_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_openai_gpt_4_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_agent_archival_memory_retrieval(filename)
response = check_agent_archival_memory_retrieval(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_embedding_endpoint_openai():