Finish testing gpt4 openai

This commit is contained in:
Matt Zhou
2024-10-03 18:28:58 -07:00
parent cd84d9fbdd
commit ab5d12f586
5 changed files with 364 additions and 84 deletions

View File

@@ -30,11 +30,35 @@ jobs:
run: |
poetry run letta quickstart --backend openai
- name: Test LLM endpoint
- name: Test first message contains expected function call and inner monologue
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_openai
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_returns_valid_first_message
- name: Test model sends message with keyword
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_returns_keyword
- name: Test model uses external tool correctly
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_uses_external_tool
- name: Test model recalls chat memory
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_recall_chat_memory
- name: Test model uses `archival_memory_search` to find secret
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_archival_memory_retrieval
- name: Test embedding endpoint
env:

View File

@@ -1592,7 +1592,7 @@ class LocalClient(AbstractClient):
# memory
def get_in_context_memory(self, agent_id: str) -> Memory:
"""
Get the in-contxt (i.e. core) memory of an agent
Get the in-context (i.e. core) memory of an agent
Args:
agent_id (str): ID of the agent

View File

@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
import json
from typing import TYPE_CHECKING, List, Optional, Union
# Avoid circular imports
if TYPE_CHECKING:
@@ -37,73 +38,47 @@ class LocalLLMConnectionError(LettaError):
super().__init__(self.message)
class MissingFunctionCallError(LettaError):
message: "Message"
""" The message that caused this error.
class LettaMessageError(LettaError):
"""Base error class for handling message-related errors."""
This error should be raised when a message that we expect to have a function call does not.
"""
def __init__(self, *, message: "Message") -> None:
error_msg = "The message is missing a function call: \n\n"
# Pretty print out message
message_json = message.model_dump_json(indent=4)
error_msg += f"{message_json}"
messages: List[Union["Message", "LettaMessage"]]
default_error_message: str = "An error occurred with the message."
def __init__(self, *, messages: List[Union["Message", "LettaMessage"]], explanation: Optional[str] = None) -> None:
error_msg = self.construct_error_message(messages, self.default_error_message, explanation)
super().__init__(error_msg)
self.message = message
self.messages = messages
@staticmethod
def construct_error_message(messages: List[Union["Message", "LettaMessage"]], error_msg: str, explanation: Optional[str] = None) -> str:
"""Helper method to construct a clean and formatted error message."""
if explanation:
error_msg += f" (Explanation: {explanation})"
# Pretty print out message JSON
message_json = json.dumps([message.model_dump_json(indent=4) for message in messages], indent=4)
return f"{error_msg}\n\n{message_json}"
class InvalidFunctionCallError(LettaError):
message: "Message"
""" The message that caused this error.
class MissingFunctionCallError(LettaMessageError):
"""Error raised when a message is missing a function call."""
This error should be raised when a message uses a function that is unexpected or invalid, or if the usage is incorrect.
"""
def __init__(self, *, message: "Message") -> None:
error_msg = "The message uses an invalid function call or has improper usage of a function call: \n\n"
# Pretty print out message
message_json = message.model_dump_json(indent=4)
error_msg += f"{message_json}"
super().__init__(error_msg)
self.message = message
default_error_message = "The message is missing a function call."
class MissingInnerMonologueError(LettaError):
message: "Message"
""" The message that caused this error.
class InvalidFunctionCallError(LettaMessageError):
"""Error raised when a message uses an invalid function call."""
This error should be raised when a message that we expect to have an inner monologue does not.
"""
def __init__(self, *, message: "Message") -> None:
error_msg = "The message is missing an inner monologue: \n\n"
# Pretty print out message
message_json = message.model_dump_json(indent=4)
error_msg += f"{message_json}"
super().__init__(error_msg)
self.message = message
default_error_message = "The message uses an invalid function call or has improper usage of a function call."
class InvalidInnerMonologueError(LettaError):
message: "Message"
""" The message that caused this error.
class MissingInnerMonologueError(LettaMessageError):
"""Error raised when a message is missing an inner monologue."""
This error should be raised when a message has an improperly formatted inner monologue.
"""
default_error_message = "The message is missing an inner monologue."
def __init__(self, *, message: "Message") -> None:
error_msg = "The message has a malformed inner monologue: \n\n"
# Pretty print out message
message_json = message.model_dump_json(indent=4)
error_msg += f"{message_json}"
class InvalidInnerMonologueError(LettaMessageError):
"""Error raised when a message has a malformed inner monologue."""
super().__init__(error_msg)
self.message = message
default_error_message = "The message has a malformed inner monologue."

View File

@@ -1,21 +1,51 @@
import json
from typing import Callable, Optional
import uuid
from typing import Callable, List, Optional, Union
from letta import LocalClient, RESTClient
from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.errors import (
InvalidFunctionCallError,
InvalidInnerMonologueError,
LettaError,
MissingFunctionCallError,
MissingInnerMonologueError,
)
from letta.llm_api.llm_api_tools import unpack_inner_thoughts_from_kwargs
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import (
FunctionCallMessage,
InternalMonologue,
LettaMessage,
)
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message
from letta.utils import get_human_text, get_persona_text
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent"))
def setup_llm_endpoint(filename: str, embedding_config_path: str) -> [LLMConfig, EmbeddingConfig]:
# ======================================================================================================================
# Section: Test Setup
# These functions help setup the test
# ======================================================================================================================
def setup_agent(
client: Union[LocalClient, RESTClient],
filename: str,
embedding_config_path: str,
memory_human_str: str = get_human_text(DEFAULT_HUMAN),
memory_persona_str: str = get_persona_text(DEFAULT_PERSONA),
tools: Optional[List[str]] = None,
) -> AgentState:
config_data = json.load(open(filename, "r"))
llm_config = LLMConfig(**config_data)
embedding_config = EmbeddingConfig(**json.load(open(embedding_config_path)))
@@ -26,10 +56,84 @@ def setup_llm_endpoint(filename: str, embedding_config_path: str) -> [LLMConfig,
config.default_embedding_config = embedding_config
config.save()
return llm_config, embedding_config
memory = ChatMemory(human=memory_human_str, persona=memory_persona_str)
agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tools=tools)
return agent_state
def assert_contains_valid_function_call(message: Message, function_call_validator: Optional[Callable[[FunctionCall], bool]] = None) -> None:
# ======================================================================================================================
# Section: Letta Message Assertions
# These functions are validating elements of parsed Letta Messsage
# ======================================================================================================================
def assert_sanity_checks(response: LettaResponse):
assert response is not None
assert response.messages is not None
assert len(response.messages) > 0
def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str) -> None:
# Find first instance of send_message
target_message = None
for message in messages:
if isinstance(message, FunctionCallMessage) and message.function_call.name == "send_message":
target_message = message
break
# No messages found with `send_messages`
if target_message is None:
raise LettaError("Missing send_message function call")
send_message_function_call = target_message.function_call
try:
arguments = json.loads(send_message_function_call.arguments)
except:
raise InvalidFunctionCallError(messages=[target_message], explanation="Function call arguments could not be loaded into JSON")
# Message field not in send_message
if "message" not in arguments:
raise InvalidFunctionCallError(
messages=[target_message], explanation=f"send_message function call does not have required field `message`"
)
# Check that the keyword is in the message arguments
if not keyword in arguments["message"]:
raise InvalidFunctionCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")
def assert_invoked_function_call(messages: List[LettaMessage], function_name: str) -> None:
for message in messages:
if isinstance(message, FunctionCallMessage) and message.function_call.name == function_name:
# Found it, do nothing
return
raise MissingFunctionCallError(
messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}"
)
def assert_inner_monologue_is_present_and_valid(messages: List[LettaMessage]) -> None:
for message in messages:
if isinstance(message, InternalMonologue):
# Found it, do nothing
return
raise MissingInnerMonologueError(messages=messages)
# ======================================================================================================================
# Section: Raw API Assertions
# These functions are validating elements of the (close to) raw LLM API's response
# ======================================================================================================================
def assert_contains_valid_function_call(
message: Message,
function_call_validator: Optional[Callable[[FunctionCall], bool]] = None,
validation_failure_summary: Optional[str] = None,
) -> None:
"""
Helper function to check that a message contains a valid function call.
@@ -39,33 +143,50 @@ def assert_contains_valid_function_call(message: Message, function_call_validato
if (hasattr(message, "function_call") and message.function_call is not None) and (
hasattr(message, "tool_calls") and message.tool_calls is not None
):
return False
raise InvalidFunctionCallError(messages=[message], explanation="Both function_call and tool_calls is present in the message")
elif hasattr(message, "function_call") and message.function_call is not None:
function_call = message.function_call
elif hasattr(message, "tool_calls") and message.tool_calls is not None:
# Note: We only take the first one for now. Is this a problem? @charles
# This seems to be standard across the repo
function_call = message.tool_calls[0].function
else:
# Throw a missing function call error
raise MissingFunctionCallError(message=message)
raise MissingFunctionCallError(messages=[message])
if function_call_validator and not function_call_validator(function_call):
raise InvalidFunctionCallError(message=message)
raise InvalidFunctionCallError(messages=[message], explanation=validation_failure_summary)
def inner_monologue_is_valid(monologue: str) -> bool:
def assert_inner_monologue_is_valid(message: Message) -> None:
"""
Helper function to check that the inner monologue is valid.
"""
invalid_chars = '(){}[]"'
# Sometimes the syntax won't be correct and internal syntax will leak into message
invalid_phrases = ["functions", "send_message"]
return any(char in monologue for char in invalid_chars) or any(p in monologue for p in invalid_phrases)
monologue = message.content
for char in invalid_chars:
if char in monologue:
raise InvalidInnerMonologueError(messages=[message], explanation=f"{char} is in monologue")
for phrase in invalid_phrases:
if phrase in monologue:
raise InvalidInnerMonologueError(messages=[message], explanation=f"{phrase} is in monologue")
def assert_contains_correct_inner_monologue(choice: Choice, inner_thoughts_in_kwargs: bool) -> None:
"""
Helper function to check that the inner monologue exists and is valid.
"""
# Unpack inner thoughts out of function kwargs, and repackage into choice
if inner_thoughts_in_kwargs:
choice = unpack_inner_thoughts_from_kwargs(choice, INNER_THOUGHTS_KWARG)
monologue = choice.message.content
message = choice.message
monologue = message.content
if not monologue or monologue is None or monologue == "":
raise MissingInnerMonologueError(message=choice.message)
elif not inner_monologue_is_valid(monologue):
raise InvalidInnerMonologueError(message=choice.message)
raise MissingInnerMonologueError(messages=[message])
assert_inner_monologue_is_valid(message)

View File

@@ -10,9 +10,14 @@ from letta.prompts import gpt_system
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.message import Message
from tests.helpers.endpoints_helper import (
agent_uuid,
assert_contains_correct_inner_monologue,
assert_contains_valid_function_call,
setup_llm_endpoint,
assert_inner_monologue_is_present_and_valid,
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
assert_sanity_checks,
setup_agent,
)
from tests.helpers.utils import cleanup
@@ -26,17 +31,21 @@ llm_config_path = "configs/llm_model_configs/letta-hosted.json"
embedding_config_dir = "configs/embedding_model_configs"
llm_config_dir = "configs/llm_model_configs"
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent"))
def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False):
llm_config, embedding_config = setup_llm_endpoint(filename, embedding_config_path)
"""
Checks that the first response is valid:
1. Contains either send_message or archival_memory_search
2. Contains valid usage of the function
3. Contains inner monologue
Note: This is acting on the raw LLM response, note the usage of `create`
"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config)
agent_state = setup_agent(client, filename, embedding_config_path)
tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools]
agent = Agent(
interface=None,
@@ -45,9 +54,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts
)
response = create(
llm_config=llm_config,
user_id=uuid.UUID(int=1), # dummy user_id
# messages=agent_state.messages,
llm_config=agent_state.llm_config,
user_id=str(uuid.UUID(int=1)), # dummy user_id
messages=agent._messages,
functions=agent.functions,
functions_python=agent.functions_python,
@@ -63,10 +71,130 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts
validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search"
assert_contains_valid_function_call(choice.message, validator_func)
# Assert that the choice has an inner monologue
# Assert that the message has an inner monologue
assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs)
def check_response_contains_keyword(filename: str):
"""
Checks that the prompted response from the LLM contains a chosen keyword
Note: This is acting on the Letta response, note the usage of `user_message`
"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename, embedding_config_path)
keyword = "banana"
keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"'
response = client.user_message(agent_id=agent_state.id, message=keyword_message)
# Basic checks
assert_sanity_checks(response)
# Make sure the message was sent
assert_invoked_send_message_with_keyword(response.messages, keyword)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
def check_agent_uses_external_tool(filename: str):
"""
Checks that the LLM will use external tools if instructed
Note: This is acting on the Letta response, note the usage of `user_message`
"""
from crewai_tools import ScrapeWebsiteTool
from letta.schemas.tool import Tool
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
tool = Tool.from_crewai(crewai_tool)
tool_name = tool.name
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
client.add_tool(tool)
# Set up persona for tool usage
persona = f"""
My name is Letta.
I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question.
Dont forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts.
"""
agent_state = setup_agent(client, filename, embedding_config_path, memory_persona_str=persona, tools=[tool_name])
response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?")
# Basic checks
assert_sanity_checks(response)
# Make sure the tool was called
assert_invoked_function_call(response.messages, tool_name)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
def check_agent_recall_chat_memory(filename: str):
"""
Checks that the LLM will recall the chat memory, specifically the human persona.
Note: This is acting on the Letta response, note the usage of `user_message`
"""
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
human_name = "BananaBoy"
agent_state = setup_agent(client, filename, embedding_config_path, memory_human_str=f"My name is {human_name}")
response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.")
# Basic checks
assert_sanity_checks(response)
# Make sure my name was repeated back to me
assert_invoked_send_message_with_keyword(response.messages, human_name)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
def check_agent_archival_memory_retrieval(filename: str):
"""
Checks that the LLM will execute an archival memory retrieval.
Note: This is acting on the Letta response, note the usage of `user_message`
"""
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename, embedding_config_path)
secret_word = "banana"
client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!")
response = client.user_message(agent_id=agent_state.id, message="Search archival memory for the secret word and repeat it back to me.")
# Basic checks
assert_sanity_checks(response)
# Make sure archival_memory_search was called
assert_invoked_function_call(response.messages, "archival_memory_search")
# Make sure secret was repeated back to me
assert_invoked_send_message_with_keyword(response.messages, secret_word)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
def run_embedding_endpoint(filename):
# load JSON file
config_data = json.load(open(filename, "r"))
@@ -79,16 +207,42 @@ def run_embedding_endpoint(filename):
assert query_vec is not None
def test_llm_endpoint_openai():
# ======================================================================================================================
# OPENAI TESTS
# ======================================================================================================================
def test_openai_gpt_4_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_first_response_is_valid_for_llm_endpoint(filename)
def test_openai_gpt_4_returns_keyword():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_response_contains_keyword(filename)
def test_openai_gpt_4_uses_external_tool():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_agent_uses_external_tool(filename)
def test_openai_gpt_4_recall_chat_memory():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_agent_recall_chat_memory(filename)
def test_openai_gpt_4_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "gpt-4.json")
check_agent_archival_memory_retrieval(filename)
def test_embedding_endpoint_openai():
filename = os.path.join(embedding_config_dir, "text-embedding-ada-002.json")
run_embedding_endpoint(filename)
# ======================================================================================================================
# LETTA HOSTED
# ======================================================================================================================
def test_llm_endpoint_letta_hosted():
filename = os.path.join(llm_config_dir, "letta-hosted.json")
check_first_response_is_valid_for_llm_endpoint(filename)
@@ -99,6 +253,9 @@ def test_embedding_endpoint_letta_hosted():
run_embedding_endpoint(filename)
# ======================================================================================================================
# LOCAL MODELS
# ======================================================================================================================
def test_embedding_endpoint_local():
filename = os.path.join(embedding_config_dir, "local.json")
run_embedding_endpoint(filename)
@@ -114,6 +271,9 @@ def test_embedding_endpoint_ollama():
run_embedding_endpoint(filename)
# ======================================================================================================================
# ANTHROPIC TESTS
# ======================================================================================================================
def test_llm_endpoint_anthropic():
filename = os.path.join(llm_config_dir, "anthropic.json")
check_first_response_is_valid_for_llm_endpoint(filename)