feat: split up endpoint tests (#1432)
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
28
.github/workflows/test_local.yml
vendored
Normal file
28
.github/workflows/test_local.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Endpoint (Local)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: "Setup Python, Poetry and Dependencies"
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev -E local"
|
||||
|
||||
- name: Test embedding endpoint
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_local
|
||||
30
.github/workflows/test_memgpt_hosted.yml
vendored
Normal file
30
.github/workflows/test_memgpt_hosted.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: Endpoint (MemGPT)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: "Setup Python, Poetry and Dependencies"
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev"
|
||||
|
||||
- name: Test LLM endpoint
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_memgpt_hosted
|
||||
|
||||
- name: Test embedding endpoint
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_memgpt_hosted
|
||||
43
.github/workflows/test_openai.yml
vendored
Normal file
43
.github/workflows/test_openai.yml
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
name: Endpoint (OpenAI)
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: "Setup Python, Poetry and Dependencies"
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev"
|
||||
|
||||
- name: Initialize credentials
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
poetry run memgpt quickstart --backend openai
|
||||
|
||||
- name: Test LLM endpoint
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_openai
|
||||
|
||||
- name: Test embedding endpoint
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_openai
|
||||
16
.github/workflows/tests.yml
vendored
16
.github/workflows/tests.yml
vendored
@@ -2,7 +2,6 @@ name: Run All pytest Tests
|
||||
|
||||
env:
|
||||
MEMGPT_PGURI: ${{ secrets.MEMGPT_PGURI }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -26,17 +25,11 @@ jobs:
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "--all-extras"
|
||||
install-args: "-E dev -E postgres -E milvus"
|
||||
|
||||
- name: Initialize credentials
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
if [ -z "$OPENAI_API_KEY" ]; then
|
||||
poetry run memgpt quickstart --backend openai
|
||||
else
|
||||
poetry run memgpt quickstart --backend memgpt
|
||||
fi
|
||||
poetry run memgpt quickstart --backend memgpt
|
||||
|
||||
#- name: Run docker compose server
|
||||
# env:
|
||||
@@ -55,7 +48,6 @@ jobs:
|
||||
MEMGPT_PG_PASSWORD: memgpt
|
||||
MEMGPT_PG_DB: memgpt
|
||||
MEMGPT_PG_HOST: localhost
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MEMGPT_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_server.py
|
||||
@@ -67,11 +59,10 @@ jobs:
|
||||
MEMGPT_PG_PASSWORD: memgpt
|
||||
MEMGPT_PG_HOST: localhost
|
||||
MEMGPT_PG_DB: memgpt
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MEMGPT_SERVER_PASS: test_server_token
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
run: |
|
||||
poetry run pytest -s -vv -k "not test_storage and not test_server and not test_openai_client" tests
|
||||
poetry run pytest -s -vv -k "not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client" tests
|
||||
|
||||
- name: Run storage tests
|
||||
env:
|
||||
@@ -80,7 +71,6 @@ jobs:
|
||||
MEMGPT_PG_PASSWORD: memgpt
|
||||
MEMGPT_PG_HOST: localhost
|
||||
MEMGPT_PG_DB: memgpt
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MEMGPT_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_storage.py
|
||||
|
||||
7
configs/embedding_model_configs/local.json
Normal file
7
configs/embedding_model_configs/local.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"embedding_endpoint": null,
|
||||
"embedding_model": "BAAI/bge-small-en-v1.5",
|
||||
"embedding_dim": 384,
|
||||
"embedding_chunk_size": 300,
|
||||
"embedding_endpoint_type": "local"
|
||||
}
|
||||
@@ -2,5 +2,6 @@
|
||||
"embedding_endpoint": "https://embeddings.memgpt.ai",
|
||||
"embedding_model": "BAAI/bge-large-en-v1.5",
|
||||
"embedding_dim": 1024,
|
||||
"embedding_chunk_size": 300
|
||||
"embedding_chunk_size": 300,
|
||||
"embedding_endpoint_type": "hugging-face"
|
||||
}
|
||||
|
||||
@@ -476,11 +476,15 @@ class RESTClient(AbstractClient):
|
||||
) -> GetAgentMessagesResponse:
|
||||
params = {"before": before, "after": after, "limit": limit}
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages-cursor", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get messages: {response.text}")
|
||||
return GetAgentMessagesResponse(**response.json())
|
||||
|
||||
def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse:
|
||||
data = {"message": message, "role": role, "stream": stream}
|
||||
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to send message: {response.text}")
|
||||
return UserMessageResponse(**response.json())
|
||||
|
||||
# humans / personas
|
||||
|
||||
@@ -162,7 +162,6 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
||||
credentials = MemGPTCredentials.load()
|
||||
|
||||
if endpoint_type == "openai":
|
||||
assert credentials.openai_key is not None
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
additional_kwargs = {"user_id": user_id} if user_id else {}
|
||||
|
||||
@@ -223,14 +223,6 @@ class SyncServer(LockingServer):
|
||||
# TODO figure out how to handle credentials for the server
|
||||
self.credentials = MemGPTCredentials.load()
|
||||
|
||||
# check credentials
|
||||
# TODO: add checks for other providers
|
||||
if (
|
||||
self.config.default_embedding_config.embedding_endpoint_type == "openai"
|
||||
or self.config.default_llm_config.model_endpoint_type == "openai"
|
||||
):
|
||||
assert self.credentials.openai_key is not None, "OpenAI key must be set in the credentials file"
|
||||
|
||||
# Ensure valid database configuration
|
||||
# TODO: add back once tests are matched
|
||||
# assert (
|
||||
|
||||
@@ -7,12 +7,12 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from memgpt import Admin, create_client
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import Preset # TODO move to PresetModel
|
||||
from memgpt.data_types import EmbeddingConfig, LLMConfig
|
||||
from memgpt.settings import settings
|
||||
from tests.config import TestMGPTConfig
|
||||
from tests.utils import create_config
|
||||
|
||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
# test_preset_name = "test_preset"
|
||||
@@ -34,54 +34,24 @@ def _reset_config():
|
||||
db_url = settings.memgpt_pg_uri
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
default_embedding_config=EmbeddingConfig(
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
embedding_dim=1536,
|
||||
),
|
||||
# llms
|
||||
default_llm_config=LLMConfig(
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt-4",
|
||||
),
|
||||
)
|
||||
create_config("openai")
|
||||
credentials = MemGPTCredentials(
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
default_embedding_config=EmbeddingConfig(
|
||||
embedding_endpoint_type="hugging-face",
|
||||
embedding_endpoint="https://embeddings.memgpt.ai",
|
||||
embedding_model="BAAI/bge-large-en-v1.5",
|
||||
embedding_dim=1024,
|
||||
),
|
||||
# llms
|
||||
default_llm_config=LLMConfig(
|
||||
model_endpoint_type="vllm",
|
||||
model_endpoint="https://api.memgpt.ai",
|
||||
model="ehartford/dolphin-2.5-mixtral-8x7b",
|
||||
),
|
||||
)
|
||||
create_config("memgpt_hosted")
|
||||
credentials = MemGPTCredentials()
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# set to use postgres
|
||||
config.archival_storage_uri = db_url
|
||||
config.recall_storage_uri = db_url
|
||||
config.metadata_storage_uri = db_url
|
||||
config.archival_storage_type = "postgres"
|
||||
config.recall_storage_type = "postgres"
|
||||
config.metadata_storage_type = "postgres"
|
||||
|
||||
config.save()
|
||||
credentials.save()
|
||||
print("_reset_config :: ", config.config_path)
|
||||
|
||||
@@ -12,58 +12,74 @@ from memgpt.prompts import gpt_system
|
||||
|
||||
messages = [Message(role="system", text=gpt_system.get_system_text("memgpt_chat")), Message(role="user", text="How are you?")]
|
||||
|
||||
# defaults (memgpt hosted)
|
||||
embedding_config_path = "configs/embedding_model_configs/memgpt-hosted.json"
|
||||
llm_config_path = "configs/llm_model_configs/memgpt-hosted.json"
|
||||
|
||||
|
||||
def test_embedding_endpoints():
|
||||
|
||||
embedding_config_dir = "configs/embedding_model_configs"
|
||||
|
||||
# list JSON files in directory
|
||||
for file in os.listdir(embedding_config_dir):
|
||||
if file.endswith(".json"):
|
||||
# load JSON file
|
||||
print("testing", file)
|
||||
config_data = json.load(open(os.path.join(embedding_config_dir, file)))
|
||||
embedding_config = EmbeddingConfigModel(**config_data)
|
||||
model = embedding_model(embedding_config)
|
||||
query_text = "hello"
|
||||
query_vec = model.get_text_embedding(query_text)
|
||||
print("vector dim", len(query_vec))
|
||||
# directories
|
||||
embedding_config_dir = "configs/embedding_model_configs"
|
||||
llm_config_dir = "configs/llm_model_configs"
|
||||
|
||||
|
||||
def test_llm_endpoints():
|
||||
llm_config_dir = "configs/llm_model_configs"
|
||||
|
||||
# use openai default config
|
||||
def run_llm_endpoint(filename):
|
||||
config_data = json.load(open(filename, "r"))
|
||||
print(config_data)
|
||||
llm_config = LLMConfigModel(**config_data)
|
||||
embedding_config = EmbeddingConfigModel(**json.load(open(embedding_config_path)))
|
||||
agent = Agent(
|
||||
interface=None,
|
||||
preset=load_preset("memgpt_chat", user_id=uuid.UUID(int=1)),
|
||||
name="test_agent",
|
||||
created_by=uuid.UUID(int=1),
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True,
|
||||
)
|
||||
|
||||
# list JSON files in directory
|
||||
for file in os.listdir(llm_config_dir):
|
||||
if file.endswith(".json"):
|
||||
# load JSON file
|
||||
print("testing", file)
|
||||
config_data = json.load(open(os.path.join(llm_config_dir, file)))
|
||||
print(config_data)
|
||||
llm_config = LLMConfigModel(**config_data)
|
||||
agent = Agent(
|
||||
interface=None,
|
||||
preset=load_preset("memgpt_chat", user_id=uuid.UUID(int=1)),
|
||||
name="test_agent",
|
||||
created_by=uuid.UUID(int=1),
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True,
|
||||
)
|
||||
response = create(
|
||||
llm_config=llm_config,
|
||||
user_id=uuid.UUID(int=1), # dummy user_id
|
||||
# messages=agent_state.messages,
|
||||
messages=agent._messages,
|
||||
functions=agent.functions,
|
||||
functions_python=agent.functions_python,
|
||||
)
|
||||
assert response is not None
|
||||
|
||||
response = create(
|
||||
llm_config=llm_config,
|
||||
user_id=uuid.UUID(int=1), # dummy user_id
|
||||
# messages=agent_state.messages,
|
||||
messages=agent._messages,
|
||||
functions=agent.functions,
|
||||
functions_python=agent.functions_python,
|
||||
)
|
||||
assert response is not None
|
||||
|
||||
def run_embedding_endpoint(filename):
|
||||
# load JSON file
|
||||
config_data = json.load(open(filename, "r"))
|
||||
print(config_data)
|
||||
embedding_config = EmbeddingConfigModel(**config_data)
|
||||
model = embedding_model(embedding_config)
|
||||
query_text = "hello"
|
||||
query_vec = model.get_text_embedding(query_text)
|
||||
print("vector dim", len(query_vec))
|
||||
assert query_vec is not None
|
||||
|
||||
|
||||
def test_llm_endpoint_openai():
|
||||
filename = os.path.join(llm_config_dir, "gpt-4.json")
|
||||
run_llm_endpoint(filename)
|
||||
|
||||
|
||||
def test_embedding_endpoint_openai():
|
||||
filename = os.path.join(embedding_config_dir, "text-embedding-ada-002.json")
|
||||
run_embedding_endpoint(filename)
|
||||
|
||||
|
||||
def test_llm_endpoint_memgpt_hosted():
|
||||
filename = os.path.join(llm_config_dir, "memgpt-hosted.json")
|
||||
run_llm_endpoint(filename)
|
||||
|
||||
|
||||
def test_embedding_endpoint_memgpt_hosted():
|
||||
filename = os.path.join(embedding_config_dir, "memgpt-hosted.json")
|
||||
run_embedding_endpoint(filename)
|
||||
|
||||
|
||||
def test_embedding_endpoint_local():
|
||||
filename = os.path.join(embedding_config_dir, "local.json")
|
||||
run_embedding_endpoint(filename)
|
||||
|
||||
@@ -6,8 +6,9 @@ from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.cli.cli_load import load_directory
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, User
|
||||
from memgpt.data_types import AgentState, EmbeddingConfig, User
|
||||
from memgpt.metadata import MetadataStore
|
||||
|
||||
# from memgpt.data_sources.connectors import DirectoryConnector, load_data
|
||||
@@ -16,7 +17,7 @@ from memgpt.settings import settings
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
|
||||
from .utils import wipe_config
|
||||
from .utils import create_config, wipe_config
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -46,17 +47,18 @@ def test_load_directory(
|
||||
recreate_declarative_base,
|
||||
):
|
||||
wipe_config()
|
||||
TEST_MEMGPT_CONFIG.default_embedding_config = EmbeddingConfig(
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
)
|
||||
TEST_MEMGPT_CONFIG.default_llm_config = LLMConfig(
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt-4",
|
||||
)
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
create_config("openai")
|
||||
credentials = MemGPTCredentials(
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
create_config("memgpt_hosted")
|
||||
credentials = MemGPTCredentials()
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
TEST_MEMGPT_CONFIG.default_embedding_config = config.default_embedding_config
|
||||
TEST_MEMGPT_CONFIG.default_llm_config = config.default_llm_config
|
||||
|
||||
# setup config
|
||||
if metadata_storage_connector == "postgres":
|
||||
|
||||
@@ -5,15 +5,14 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import memgpt.utils as utils
|
||||
from tests.config import TestMGPTConfig
|
||||
|
||||
utils.DEBUG = True
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import EmbeddingConfig, LLMConfig
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.settings import settings
|
||||
|
||||
from .utils import DummyDataConnector, wipe_config, wipe_memgpt_home
|
||||
from .utils import DummyDataConnector, create_config, wipe_config, wipe_memgpt_home
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -24,55 +23,28 @@ def server():
|
||||
|
||||
db_url = settings.memgpt_pg_uri
|
||||
|
||||
# Use os.getenv with a fallback to os.environ.get
|
||||
db_url = settings.memgpt_pg_uri
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
default_embedding_config=EmbeddingConfig(
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
embedding_dim=1536,
|
||||
),
|
||||
# llms
|
||||
default_llm_config=LLMConfig(
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt-4",
|
||||
),
|
||||
)
|
||||
create_config("openai")
|
||||
credentials = MemGPTCredentials(
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
default_embedding_config=EmbeddingConfig(
|
||||
embedding_endpoint_type="hugging-face",
|
||||
embedding_endpoint="https://embeddings.memgpt.ai",
|
||||
embedding_model="BAAI/bge-large-en-v1.5",
|
||||
embedding_dim=1024,
|
||||
),
|
||||
# llms
|
||||
default_llm_config=LLMConfig(
|
||||
model_endpoint_type="vllm",
|
||||
model_endpoint="https://api.memgpt.ai",
|
||||
model="ehartford/dolphin-2.5-mixtral-8x7b",
|
||||
),
|
||||
)
|
||||
create_config("memgpt_hosted")
|
||||
credentials = MemGPTCredentials()
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# set to use postgres
|
||||
config.archival_storage_uri = db_url
|
||||
config.recall_storage_uri = db_url
|
||||
config.metadata_storage_uri = db_url
|
||||
config.archival_storage_type = "postgres"
|
||||
config.recall_storage_type = "postgres"
|
||||
config.metadata_storage_type = "postgres"
|
||||
|
||||
config.save()
|
||||
credentials.save()
|
||||
|
||||
|
||||
@@ -6,21 +6,16 @@ import pytest
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import MAX_EMBEDDING_DIM
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import (
|
||||
AgentState,
|
||||
EmbeddingConfig,
|
||||
LLMConfig,
|
||||
Message,
|
||||
Passage,
|
||||
User,
|
||||
)
|
||||
from memgpt.data_types import AgentState, Message, Passage, User
|
||||
from memgpt.embeddings import embedding_model, query_embedding
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.settings import settings
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
from tests.utils import create_config, wipe_config
|
||||
|
||||
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
|
||||
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
@@ -124,17 +119,21 @@ def test_storage(
|
||||
# if 'Message' in globals():
|
||||
# print("Removing messages", globals()['Message'])
|
||||
# del globals()['Message']
|
||||
TEST_MEMGPT_CONFIG.default_embedding_config = EmbeddingConfig(
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
)
|
||||
TEST_MEMGPT_CONFIG.default_llm_config = LLMConfig(
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt-4",
|
||||
)
|
||||
|
||||
wipe_config()
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
create_config("openai")
|
||||
credentials = MemGPTCredentials(
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
create_config("memgpt_hosted")
|
||||
MemGPTCredentials()
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
TEST_MEMGPT_CONFIG.default_embedding_config = config.default_embedding_config
|
||||
TEST_MEMGPT_CONFIG.default_llm_config = config.default_llm_config
|
||||
|
||||
if storage_connector == "postgres":
|
||||
TEST_MEMGPT_CONFIG.archival_storage_uri = settings.memgpt_pg_uri
|
||||
TEST_MEMGPT_CONFIG.recall_storage_uri = settings.memgpt_pg_uri
|
||||
@@ -167,21 +166,8 @@ def test_storage(
|
||||
TEST_MEMGPT_CONFIG.archival_storage_type = "milvus"
|
||||
TEST_MEMGPT_CONFIG.archival_storage_uri = "./milvus.db"
|
||||
# get embedding model
|
||||
embed_model = None
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
embedding_config = EmbeddingConfig(
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
# openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
credentials = MemGPTCredentials(
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
credentials.save()
|
||||
else:
|
||||
embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384)
|
||||
embed_model = embedding_model(embedding_config)
|
||||
embedding_config = TEST_MEMGPT_CONFIG.default_embedding_config
|
||||
embed_model = embedding_model(TEST_MEMGPT_CONFIG.default_embedding_config)
|
||||
|
||||
# create user
|
||||
ms = MetadataStore(TEST_MEMGPT_CONFIG)
|
||||
|
||||
Reference in New Issue
Block a user