diff --git a/.github/workflows/docker-integration-tests.yaml b/.github/workflows/docker-integration-tests.yaml index 9060d499..fe091fde 100644 --- a/.github/workflows/docker-integration-tests.yaml +++ b/.github/workflows/docker-integration-tests.yaml @@ -58,6 +58,7 @@ jobs: pipx install poetry==1.8.2 poetry install -E dev poetry run pytest -s tests/test_client.py + poetry run pytest -s tests/test_concurrent_connections.py - name: Print docker logs if tests fail if: failure() diff --git a/memgpt/client/client.py b/memgpt/client/client.py index bd452363..b0226cd6 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -239,6 +239,7 @@ class RESTClient(AbstractClient): super().__init__(debug=debug) self.base_url = base_url self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"} + self.token = token # agents diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py index eca76413..62fd341b 100644 --- a/memgpt/server/rest_api/agents/index.py +++ b/memgpt/server/rest_api/agents/index.py @@ -2,7 +2,7 @@ import uuid from functools import partial from typing import List -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Body, Depends, HTTPException from pydantic import BaseModel, Field from memgpt.models.pydantic_models import ( @@ -60,56 +60,56 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p """ interface.clear() - # try: - agent_state = server.create_agent( - user_id=user_id, - # **request.config - # TODO turn into a pydantic model - name=request.config["name"], - preset=request.config["preset"] if "preset" in request.config else None, - persona_name=request.config["persona_name"] if "persona_name" in request.config else None, - human_name=request.config["human_name"] if "human_name" in request.config else None, - persona=request.config["persona"] if "persona" in request.config else None, - human=request.config["human"] if "human" in 
request.config else None, - # llm_config=LLMConfigModel( - # model=request.config['model'], - # ) - function_names=request.config["function_names"].split(",") if "function_names" in request.config else None, - ) - llm_config = LLMConfigModel(**vars(agent_state.llm_config)) - embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config)) + try: + agent_state = server.create_agent( + user_id=user_id, + # **request.config + # TODO turn into a pydantic model + name=request.config["name"], + preset=request.config["preset"] if "preset" in request.config else None, + persona_name=request.config["persona_name"] if "persona_name" in request.config else None, + human_name=request.config["human_name"] if "human_name" in request.config else None, + persona=request.config["persona"] if "persona" in request.config else None, + human=request.config["human"] if "human" in request.config else None, + # llm_config=LLMConfigModel( + # model=request.config['model'], + # ) + function_names=request.config["function_names"].split(",") if "function_names" in request.config else None, + ) + llm_config = LLMConfigModel(**vars(agent_state.llm_config)) + embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config)) - # TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line - preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id) + # TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line + preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id) - return CreateAgentResponse( - agent_state=AgentStateModel( - id=agent_state.id, - name=agent_state.name, - user_id=agent_state.user_id, - preset=agent_state.preset, - persona=agent_state.persona, - human=agent_state.human, - llm_config=llm_config, - embedding_config=embedding_config, - state=agent_state.state, - created_at=int(agent_state.created_at.timestamp()), - 
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead - ), - preset=PresetModel( - name=preset.name, - id=preset.id, - user_id=preset.user_id, - description=preset.description, - created_at=preset.created_at, - system=preset.system, - persona=preset.persona, - human=preset.human, - functions_schema=preset.functions_schema, - ), - ) - # except Exception as e: - # print(str(e)) - # raise HTTPException(status_code=500, detail=str(e)) + return CreateAgentResponse( + agent_state=AgentStateModel( + id=agent_state.id, + name=agent_state.name, + user_id=agent_state.user_id, + preset=agent_state.preset, + persona=agent_state.persona, + human=agent_state.human, + llm_config=llm_config, + embedding_config=embedding_config, + state=agent_state.state, + created_at=int(agent_state.created_at.timestamp()), + functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead + ), + preset=PresetModel( + name=preset.name, + id=preset.id, + user_id=preset.user_id, + description=preset.description, + created_at=preset.created_at, + system=preset.system, + persona=preset.persona, + human=preset.human, + functions_schema=preset.functions_schema, + ), + ) + except Exception as e: + print(str(e)) + raise HTTPException(status_code=500, detail=str(e)) return router diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 01cb6565..68c44e09 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -146,11 +146,11 @@ class LockingServer(Server): return wrapper - @agent_lock_decorator + # @agent_lock_decorator def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: raise NotImplementedError - @agent_lock_decorator + # @agent_lock_decorator def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]: raise NotImplementedError @@ -515,7 +515,7 @@ class SyncServer(LockingServer): input_message = 
system.get_token_limit_warning() self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) - @LockingServer.agent_lock_decorator + # @LockingServer.agent_lock_decorator def user_message( self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message], timestamp: Optional[datetime] = None ) -> None: @@ -564,7 +564,7 @@ class SyncServer(LockingServer): # Run the agent state forward self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message) - @LockingServer.agent_lock_decorator + # @LockingServer.agent_lock_decorator def system_message( self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message], timestamp: Optional[datetime] = None ) -> None: @@ -613,7 +613,7 @@ class SyncServer(LockingServer): # Run the agent state forward self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message) - @LockingServer.agent_lock_decorator + # @LockingServer.agent_lock_decorator def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]: """Run a command on the agent""" if self.ms.get_user(user_id=user_id) is None: @@ -675,7 +675,7 @@ class SyncServer(LockingServer): # NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation # TODO: fix this db dependency and remove - # self.ms.create_agent(agent_state) + # self.ms.#create_agent(agent_state) # TODO modify to do creation via preset try: diff --git a/tests/test_concurrent_connections.py b/tests/test_concurrent_connections.py new file mode 100644 index 00000000..21e1f9a0 --- /dev/null +++ b/tests/test_concurrent_connections.py @@ -0,0 +1,143 @@ +import os +import threading +import time +import uuid + +import pytest +from dotenv import load_dotenv + +from memgpt import Admin, create_client +from memgpt.config import MemGPTConfig +from memgpt.constants import DEFAULT_PRESET +from memgpt.credentials import MemGPTCredentials +from 
memgpt.data_types import Preset # TODO move to PresetModel +from memgpt.settings import settings +from tests.utils import create_config + +test_agent_name = f"test_client_{str(uuid.uuid4())}" +# test_preset_name = "test_preset" +test_preset_name = DEFAULT_PRESET +test_agent_state = None +client = None + +test_agent_state_post_message = None +test_user_id = uuid.uuid4() + + +# admin credentials +test_server_token = "test_server_token" + + +def _reset_config(): + + # Read the Postgres URI from the MemGPT settings object + db_url = settings.memgpt_pg_uri + + if os.getenv("OPENAI_API_KEY"): + create_config("openai") + credentials = MemGPTCredentials( + openai_key=os.getenv("OPENAI_API_KEY"), + ) + else: # hosted + create_config("memgpt_hosted") + credentials = MemGPTCredentials() + + config = MemGPTConfig.load() + + # set to use postgres + config.archival_storage_uri = db_url + config.recall_storage_uri = db_url + config.metadata_storage_uri = db_url + config.archival_storage_type = "postgres" + config.recall_storage_type = "postgres" + config.metadata_storage_type = "postgres" + + config.save() + credentials.save() + print("_reset_config :: ", config.config_path) + + +def run_server(): + + load_dotenv() + + _reset_config() + + from memgpt.server.rest_api.server import start_server + + print("Starting server...") + start_server(debug=True) + + +# Fixture to create clients with different configurations +@pytest.fixture( + params=[ # whether to use REST API server + {"server": True}, + # {"server": False} # TODO: add when implemented + ], + scope="module", +) +def admin_client(request): + if request.param["server"]: + # get URL from environment + server_url = os.getenv("MEMGPT_SERVER_URL") + if server_url is None: + # run server in thread + # NOTE: must set MEMGPT_SERVER_PASS environment variable + server_url = "http://localhost:8283" + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running
client tests with server:", server_url) + # create user via admin client + admin = Admin(server_url, test_server_token) + response = admin.create_user(test_user_id) # Adjust as per your client's method + + yield admin + + +def test_concurrent_messages(admin_client): + # test concurrent messages + + # create two users (one per thread), each with its own client and agent + + results = [] + + def _send_message(): + try: + print("START SEND MESSAGE") + response = admin_client.create_user() + token = response.api_key + client = create_client(base_url=admin_client.base_url, token=token) + agent = client.create_agent( + name=test_agent_name, + ) + + print("Agent created", agent.id) + + st = time.time() + message = "Hello, how are you?" + response = client.send_message(agent_id=agent.id, message=message, role="user") + et = time.time() + print(f"Message sent from {st} to {et}") + results.append((st, et)) + except Exception as e: + print("ERROR", e) + + threads = [] + print("Starting threads...") + for i in range(2): + thread = threading.Thread(target=_send_message) + threads.append(thread) + thread.start() + print("CREATED THREAD") + + print("waiting for threads to finish...") + for thread in threads: + print(thread.join()) + + # make sure the runtimes overlap + assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or ( + results[1][0] < results[0][0] and results[1][1] > results[0][0] + ), f"Threads should have overlapping runtimes {results}"