feat: add testing for LLM + embedding endpoints (#1308)
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
6
configs/embedding_model_configs/memgpt-hosted.json
Normal file
6
configs/embedding_model_configs/memgpt-hosted.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"embedding_endpoint": "https://embeddings.memgpt.ai",
|
||||
"embedding_model": "BAAI/bge-large-en-v1.5",
|
||||
"embedding_dim": 1024,
|
||||
"embedding_chunk_size": 300
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"embedding_endpoint_type": "openai",
|
||||
"embedding_endpoint": "https://api.openai.com/v1",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"embedding_dim": 1536,
|
||||
"embedding_chunk_size": 300
|
||||
}
|
||||
7
configs/llm_model_configs/gpt-4.json
Normal file
7
configs/llm_model_configs/gpt-4.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"context_window": 8192,
|
||||
"model": "gpt-4",
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://api.openai.com/v1",
|
||||
"model_wrapper": null
|
||||
}
|
||||
6
configs/llm_model_configs/memgpt-hosted.json
Normal file
6
configs/llm_model_configs/memgpt-hosted.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"context_window": 16384,
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://inference.memgpt.ai",
|
||||
"model": "memgpt-openai"
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
"context_window": 16384,
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://inference.memgpt.ai",
|
||||
"model": "memgpt-openai",
|
||||
"embedding_endpoint_type": "hugging-face",
|
||||
"embedding_endpoint": "https://embeddings.memgpt.ai",
|
||||
"embedding_model": "BAAI/bge-large-en-v1.5",
|
||||
|
||||
@@ -78,10 +78,10 @@ class Message(Record):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
agent_id: uuid.UUID,
|
||||
role: str,
|
||||
text: str,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
agent_id: Optional[uuid.UUID] = None,
|
||||
model: Optional[str] = None, # model used to make function call
|
||||
name: Optional[str] = None, # optional participant name
|
||||
created_at: Optional[datetime] = None,
|
||||
|
||||
@@ -363,7 +363,8 @@ def openai_chat_completions_request(
|
||||
printd(f"Sending request to {url}")
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
printd(f"response = {response}")
|
||||
# printd(f"response = {response}, response.text = {response.text}")
|
||||
print(f"response = {response}, response.text = {response.text}")
|
||||
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
|
||||
|
||||
response = response.json() # convert to dict from string
|
||||
|
||||
@@ -42,6 +42,7 @@ class ToolMessage(BaseModel):
|
||||
ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
|
||||
|
||||
|
||||
# TODO: this might not be necessary with the validator
|
||||
def cast_message_to_subtype(m_dict: dict) -> ChatMessage:
|
||||
"""Cast a dictionary to one of the individual message types"""
|
||||
role = m_dict.get("role")
|
||||
|
||||
@@ -65,6 +65,25 @@ def create_preset_from_file(filename: str, name: str, user_id: uuid.UUID, ms: Me
|
||||
return preset
|
||||
|
||||
|
||||
def load_preset(preset_name: str, user_id: uuid.UUID):
|
||||
preset_config = available_presets[preset_name]
|
||||
preset_system_prompt = preset_config["system_prompt"]
|
||||
preset_function_set_names = preset_config["functions"]
|
||||
functions_schema = generate_functions_json(preset_function_set_names)
|
||||
|
||||
preset = Preset(
|
||||
user_id=user_id,
|
||||
name=preset_name,
|
||||
system=gpt_system.get_system_text(preset_system_prompt),
|
||||
persona=get_persona_text(DEFAULT_PERSONA),
|
||||
persona_name=DEFAULT_PERSONA,
|
||||
human=get_human_text(DEFAULT_HUMAN),
|
||||
human_name=DEFAULT_HUMAN,
|
||||
functions_schema=functions_schema,
|
||||
)
|
||||
return preset
|
||||
|
||||
|
||||
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||
"""Add the default presets to the metadata store"""
|
||||
# make sure humans/personas added
|
||||
@@ -72,25 +91,26 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||
|
||||
# add default presets
|
||||
for preset_name in preset_options:
|
||||
preset_config = available_presets[preset_name]
|
||||
preset_system_prompt = preset_config["system_prompt"]
|
||||
preset_function_set_names = preset_config["functions"]
|
||||
functions_schema = generate_functions_json(preset_function_set_names)
|
||||
# preset_config = available_presets[preset_name]
|
||||
# preset_system_prompt = preset_config["system_prompt"]
|
||||
# preset_function_set_names = preset_config["functions"]
|
||||
# functions_schema = generate_functions_json(preset_function_set_names)
|
||||
|
||||
if ms.get_preset(user_id=user_id, name=preset_name) is not None:
|
||||
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
|
||||
continue
|
||||
|
||||
preset = Preset(
|
||||
user_id=user_id,
|
||||
name=preset_name,
|
||||
system=gpt_system.get_system_text(preset_system_prompt),
|
||||
persona=get_persona_text(DEFAULT_PERSONA),
|
||||
persona_name=DEFAULT_PERSONA,
|
||||
human=get_human_text(DEFAULT_HUMAN),
|
||||
human_name=DEFAULT_HUMAN,
|
||||
functions_schema=functions_schema,
|
||||
)
|
||||
preset = load_preset(preset_name, user_id)
|
||||
# preset = Preset(
|
||||
# user_id=user_id,
|
||||
# name=preset_name,
|
||||
# system=gpt_system.get_system_text(preset_system_prompt),
|
||||
# persona=get_persona_text(DEFAULT_PERSONA),
|
||||
# persona_name=DEFAULT_PERSONA,
|
||||
# human=get_human_text(DEFAULT_HUMAN),
|
||||
# human_name=DEFAULT_HUMAN,
|
||||
# functions_schema=functions_schema,
|
||||
# )
|
||||
ms.create_preset(preset)
|
||||
|
||||
|
||||
|
||||
70
tests/test_endpoints.py
Normal file
70
tests/test_endpoints.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.data_types import Message
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.llm_api.llm_api_tools import create
|
||||
from memgpt.models.pydantic_models import EmbeddingConfigModel, LLMConfigModel
|
||||
from memgpt.presets.presets import load_preset
|
||||
from memgpt.prompts import gpt_system
|
||||
|
||||
messages = [Message(role="system", text=gpt_system.get_system_text("memgpt_chat")), Message(role="user", text="How are you?")]
|
||||
|
||||
embedding_config_path = "configs/embedding_model_configs/memgpt-hosted.json"
|
||||
llm_config_path = "configs/llm_model_configs/memgpt-hosted.json"
|
||||
|
||||
|
||||
def test_embedding_endpoints():
|
||||
|
||||
embedding_config_dir = "configs/embedding_model_configs"
|
||||
|
||||
# list JSON files in directory
|
||||
for file in os.listdir(embedding_config_dir):
|
||||
if file.endswith(".json"):
|
||||
# load JSON file
|
||||
print("testing", file)
|
||||
config_data = json.load(open(os.path.join(embedding_config_dir, file)))
|
||||
embedding_config = EmbeddingConfigModel(**config_data)
|
||||
# model = embedding_model(embedding_config, user_id=uuid.UUID(int=1))
|
||||
model = embedding_model(embedding_config)
|
||||
query_text = "hello"
|
||||
query_vec = model.get_text_embedding(query_text)
|
||||
print("vector dim", len(query_vec))
|
||||
|
||||
|
||||
def test_llm_endpoints():
|
||||
llm_config_dir = "configs/llm_model_configs"
|
||||
|
||||
# use openai default config
|
||||
embedding_config = EmbeddingConfigModel(**json.load(open(embedding_config_path)))
|
||||
|
||||
# list JSON files in directory
|
||||
for file in os.listdir(llm_config_dir):
|
||||
if file.endswith(".json"):
|
||||
# load JSON file
|
||||
print("testing", file)
|
||||
config_data = json.load(open(os.path.join(llm_config_dir, file)))
|
||||
print(config_data)
|
||||
llm_config = LLMConfigModel(**config_data)
|
||||
agent = Agent(
|
||||
interface=None,
|
||||
preset=load_preset("memgpt_chat", user_id=uuid.UUID(int=1)),
|
||||
name="test_agent",
|
||||
created_by=uuid.UUID(int=1),
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True,
|
||||
)
|
||||
|
||||
response = create(
|
||||
llm_config=llm_config,
|
||||
user_id=uuid.UUID(int=1), # dummy user_id
|
||||
# messages=agent_state.messages,
|
||||
messages=agent._messages,
|
||||
functions=agent.functions,
|
||||
functions_python=agent.functions_python,
|
||||
)
|
||||
assert response is not None
|
||||
Reference in New Issue
Block a user