From 635fb1cc662c23e6a01fb18379a8f2f3367e9547 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Mon, 9 Sep 2024 20:49:59 -0700 Subject: [PATCH] feat: add V1 route refactor from integration branch into separate PR (#1734) --- .github/workflows/black_format.yml | 14 +- .github/workflows/isort_format.yml | 33 +- memgpt/client/admin.py | 29 +- memgpt/client/client.py | 144 +++-- memgpt/metadata.py | 2 +- memgpt/server/rest_api/agents/index.py | 106 ---- memgpt/server/rest_api/agents/memory.py | 148 ----- memgpt/server/rest_api/agents/message.py | 197 ------- memgpt/server/rest_api/block/index.py | 73 --- memgpt/server/rest_api/config/index.py | 40 -- memgpt/server/rest_api/jobs/index.py | 41 -- memgpt/server/rest_api/models/index.py | 38 -- .../rest_api/openai_assistants/assistants.py | 488 ---------------- .../chat_completions.py | 129 ----- memgpt/server/rest_api/personas/__init__.py | 0 memgpt/server/rest_api/personas/index.py | 70 --- .../rest_api/{agents => routers}/__init__.py | 0 .../{block => routers/openai}/__init__.py | 0 .../openai/assistants}/__init__.py | 0 .../routers/openai/assistants/assistants.py | 115 ++++ .../routers/openai/assistants/schemas.py | 121 ++++ .../routers/openai/assistants/threads.py | 336 +++++++++++ .../openai/chat_completions}/__init__.py | 0 .../chat_completions/chat_completions.py | 131 +++++ memgpt/server/rest_api/routers/v1/__init__.py | 15 + memgpt/server/rest_api/routers/v1/agents.py | 529 ++++++++++++++++++ memgpt/server/rest_api/routers/v1/blocks.py | 73 +++ memgpt/server/rest_api/routers/v1/jobs.py | 46 ++ memgpt/server/rest_api/routers/v1/llms.py | 28 + memgpt/server/rest_api/routers/v1/sources.py | 199 +++++++ memgpt/server/rest_api/routers/v1/tools.py | 103 ++++ memgpt/server/rest_api/routers/v1/users.py | 109 ++++ memgpt/server/rest_api/server.py | 97 ++-- memgpt/server/rest_api/sources/__init__.py | 0 memgpt/server/rest_api/sources/index.py | 262 --------- memgpt/server/rest_api/static_files.py | 3 +- memgpt/server/rest_api/tools/__init__.py | 0 memgpt/server/rest_api/tools/index.py | 98 ---- memgpt/server/rest_api/utils.py | 13 + memgpt/server/server.py | 45 +- tests/test_tools.py | 8 +- 41 files changed, 2053 insertions(+), 1830 deletions(-) delete mode 100644 memgpt/server/rest_api/agents/index.py delete mode 100644 memgpt/server/rest_api/agents/memory.py delete mode 100644 memgpt/server/rest_api/agents/message.py delete mode 100644 memgpt/server/rest_api/block/index.py delete mode 100644 memgpt/server/rest_api/config/index.py delete mode 100644 memgpt/server/rest_api/jobs/index.py delete mode 100644 memgpt/server/rest_api/models/index.py delete mode 100644 memgpt/server/rest_api/openai_assistants/assistants.py delete mode 100644 memgpt/server/rest_api/openai_chat_completions/chat_completions.py delete mode 100644 memgpt/server/rest_api/personas/__init__.py delete mode 100644 memgpt/server/rest_api/personas/index.py rename memgpt/server/rest_api/{agents => routers}/__init__.py (100%) rename memgpt/server/rest_api/{block => routers/openai}/__init__.py (100%) rename memgpt/server/rest_api/{config => routers/openai/assistants}/__init__.py (100%) create mode 100644 memgpt/server/rest_api/routers/openai/assistants/assistants.py create mode 100644 memgpt/server/rest_api/routers/openai/assistants/schemas.py create mode 100644 memgpt/server/rest_api/routers/openai/assistants/threads.py rename memgpt/server/rest_api/{models => routers/openai/chat_completions}/__init__.py (100%) create mode 100644 
memgpt/server/rest_api/routers/openai/chat_completions/chat_completions.py create mode 100644 memgpt/server/rest_api/routers/v1/__init__.py create mode 100644 memgpt/server/rest_api/routers/v1/agents.py create mode 100644 memgpt/server/rest_api/routers/v1/blocks.py create mode 100644 memgpt/server/rest_api/routers/v1/jobs.py create mode 100644 memgpt/server/rest_api/routers/v1/llms.py create mode 100644 memgpt/server/rest_api/routers/v1/sources.py create mode 100644 memgpt/server/rest_api/routers/v1/tools.py create mode 100644 memgpt/server/rest_api/routers/v1/users.py delete mode 100644 memgpt/server/rest_api/sources/__init__.py delete mode 100644 memgpt/server/rest_api/sources/index.py delete mode 100644 memgpt/server/rest_api/tools/__init__.py delete mode 100644 memgpt/server/rest_api/tools/index.py diff --git a/.github/workflows/black_format.yml b/.github/workflows/black_format.yml index 1b8497f4..5aeea4d5 100644 --- a/.github/workflows/black_format.yml +++ b/.github/workflows/black_format.yml @@ -1,38 +1,36 @@ name: Code Formatter (Black) - on: pull_request: paths: - '**.py' workflow_dispatch: - jobs: black-check: runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 - + with: + ref: ${{ github.head_ref }} # Checkout the PR branch + fetch-depth: 0 # Fetch all history for all branches and tags - name: "Setup Python, Poetry and Dependencies" uses: packetcoders/action-setup-cache-python-poetry@main with: python-version: "3.12" poetry-version: "1.8.2" install-args: "-E dev" # TODO: change this to --group dev when PR #842 lands - - name: Run Black id: black run: poetry run black --check . continue-on-error: true - - name: Auto-fix with Black and commit - if: steps.black.outcome == 'failure' || ${{ !contains(github.event.issue.labels.*.name, 'check-only') }} + if: steps.black.outcome == 'failure' && !contains(github.event.pull_request.labels.*.name, 'check-only') run: | poetry run black . git config --local user.email "action@github.com" git config --local user.name "GitHub Action" - git commit -am "Apply Black formatting" || echo "No changes to commit" + git diff --quiet && git diff --staged --quiet || (git add -A && git commit -m "Apply Black formatting") git push - name: Error if 'check-only' label is present - if: steps.black.outcome == 'failure' && contains(github.event.issue.labels.*.name, 'check-only') + if: steps.black.outcome == 'failure' && contains(github.event.pull_request.labels.*.name, 'check-only') run: echo "Black formatting check failed. Please run 'black .' locally to fix formatting issues." && exit 1 diff --git a/.github/workflows/isort_format.yml b/.github/workflows/isort_format.yml index d31b7b68..670ec01e 100644 --- a/.github/workflows/isort_format.yml +++ b/.github/workflows/isort_format.yml @@ -1,39 +1,48 @@ name: Code Formatter (isort) - on: pull_request: paths: - '**.py' workflow_dispatch: - jobs: isort-check: runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 - + with: + ref: ${{ github.head_ref }} # Checkout the PR branch + fetch-depth: 0 # Fetch all history for all branches and tags - name: "Setup Python, Poetry and Dependencies" uses: packetcoders/action-setup-cache-python-poetry@main with: python-version: "3.12" poetry-version: "1.8.2" install-args: "-E dev" # TODO: change this to --group dev when PR #842 lands - - name: Run isort id: isort - run: poetry run isort --profile black --check-only . + run: | + output=$(poetry run isort --profile black --check-only --diff . 
| grep -v "Skipped" || true) + echo "$output" + if [ -n "$output" ]; then + echo "isort_changed=true" >> $GITHUB_OUTPUT + else + echo "isort_changed=false" >> $GITHUB_OUTPUT + fi continue-on-error: true - - name: Auto-fix with isort and commit - if: steps.isort.outcome == 'failure' || ${{ !contains(github.event.issue.labels.*.name, 'check-only') }} + if: steps.isort.outputs.isort_changed == 'true' && !contains(github.event.pull_request.labels.*.name, 'check-only') run: | poetry run isort --profile black . git config --local user.email "action@github.com" git config --local user.name "GitHub Action" - git commit -am "Apply isort import ordering" || echo "No changes to commit" - git push - - name: Error if 'check-only' label is present - if: steps.isort.outcome == 'failure' && contains(github.event.issue.labels.*.name, 'check-only') + if [[ -n $(git status -s) ]]; then + git add -A + git commit -m "Apply isort import ordering" + git push + else + echo "No changes to commit" + fi + - name: Error if 'check-only' label is present and changes are needed + if: steps.isort.outputs.isort_changed == 'true' && contains(github.event.pull_request.labels.*.name, 'check-only') run: echo "Isort check failed. Please run 'isort .' locally to fix import orders." && exit 1 - diff --git a/memgpt/client/admin.py b/memgpt/client/admin.py index c64b82c8..09293472 100644 --- a/memgpt/client/admin.py +++ b/memgpt/client/admin.py @@ -16,8 +16,14 @@ class Admin: - Generating user keys """ - def __init__(self, base_url: str, token: str): + def __init__( + self, + base_url: str, + token: str, + api_prefix: str = "v1", + ): self.base_url = base_url + self.api_prefix = api_prefix self.token = token self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"} @@ -27,35 +33,36 @@ class Admin: params["cursor"] = str(cursor) if limit: params["limit"] = limit - response = requests.get(f"{self.base_url}/admin/users", params=params, headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/users", params=params, headers=self.headers) if response.status_code != 200: raise HTTPError(response.json()) return [User(**user) for user in response.json()] def create_key(self, user_id: str, key_name: Optional[str] = None) -> APIKey: request = APIKeyCreate(user_id=user_id, name=key_name) - response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=request.model_dump()) + response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users/keys", headers=self.headers, json=request.model_dump()) if response.status_code != 200: raise HTTPError(response.json()) return APIKey(**response.json()) def get_keys(self, user_id: str) -> List[APIKey]: params = {"user_id": str(user_id)} - response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/users/keys", params=params, headers=self.headers) if response.status_code != 200: raise HTTPError(response.json()) return [APIKey(**key) for key in response.json()] def delete_key(self, api_key: str) -> APIKey: params = {"api_key": api_key} - response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers) + response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/users/keys", params=params, headers=self.headers) if response.status_code != 200: raise HTTPError(response.json()) return APIKey(**response.json()) def create_user(self, name: 
Optional[str] = None) -> User: request = UserCreate(name=name) - response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=request.model_dump()) + response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump()) if response.status_code != 200: raise HTTPError(response.json()) response_json = response.json() @@ -63,7 +70,7 @@ class Admin: def delete_user(self, user_id: str) -> User: params = {"user_id": str(user_id)} - response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers) + response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/users", params=params, headers=self.headers) if response.status_code != 200: raise HTTPError(response.json()) return User(**response.json()) @@ -114,23 +121,23 @@ class Admin: CreateToolRequest(**data) # validate # make REST request - response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/tools", json=data, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create tool: {response.text}") return ToolModel(**response.json()) def list_tools(self): - response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/tools", headers=self.headers) return ListToolsResponse(**response.json()).tools def delete_tool(self, name: str): - response = requests.delete(f"{self.base_url}/admin/tools/{name}", headers=self.headers) + response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/tools/{name}", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to delete tool: {response.text}") return response.json() def get_tool(self, name: str): - response = requests.get(f"{self.base_url}/admin/tools/{name}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/tools/{name}", headers=self.headers) if response.status_code == 404: return None elif response.status_code != 200: diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 69f3884e..d6966bcb 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -1,6 +1,6 @@ import logging import time -from typing import Dict, Generator, List, Optional, Union +from typing import Callable, Dict, Generator, List, Optional, Union import requests @@ -185,7 +185,7 @@ class AbstractClient(object): self, id: str, name: Optional[str] = None, - func: Optional[callable] = None, + func: Optional[Callable] = None, tags: Optional[List[str]] = None, ) -> Tool: raise NotImplementedError @@ -271,6 +271,7 @@ class RESTClient(AbstractClient): self, base_url: str, token: str, + api_prefix: str = "v1", debug: bool = False, ): """ @@ -283,10 +284,11 @@ """ super().__init__(debug=debug) self.base_url = base_url + self.api_prefix = api_prefix self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"} def list_agents(self) -> List[AgentState]: - response = requests.get(f"{self.base_url}/api/agents", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers) return [AgentState(**agent) for agent in response.json()] def agent_exists(self, agent_id: str) -> bool: @@ -301,7 +303,7 @@ exists (bool): `True` if
the agent exists, `False` otherwise """ - response = requests.get(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", headers=self.headers) if response.status_code == 404: # not found error return False @@ -375,7 +377,7 @@ class RESTClient(AbstractClient): embedding_config=embedding_config, ) - response = requests.post(f"{self.base_url}/api/agents", json=request.model_dump(), headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}") return AgentState(**response.json()) @@ -399,7 +401,7 @@ class RESTClient(AbstractClient): tool_call_id=tool_call_id, ) response = requests.patch( - f"{self.base_url}/api/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers ) if response.status_code != 200: raise ValueError(f"Failed to update message: {response.text}") @@ -448,7 +450,7 @@ class RESTClient(AbstractClient): message_ids=message_ids, memory=memory, ) - response = requests.post(f"{self.base_url}/api/agents/{agent_id}", json=request.model_dump(), headers=self.headers) + response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update agent: {response.text}") return AgentState(**response.json()) @@ -471,7 +473,7 @@ class RESTClient(AbstractClient): Args: agent_id (str): ID of the agent to delete """ - response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers) + response = requests.delete(f"{self.base_url}/{self.api_prefix}/agents/{str(agent_id)}", headers=self.headers) assert response.status_code == 200, f"Failed to delete agent: {response.text}" def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> AgentState: @@ -484,7 +486,7 @@ class RESTClient(AbstractClient): Returns: agent_state (AgentState): State representation of the agent """ - response = requests.get(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", headers=self.headers) assert response.status_code == 200, f"Failed to get agent: {response.text}" return AgentState(**response.json()) @@ -512,7 +514,7 @@ class RESTClient(AbstractClient): Returns: memory (Memory): In-context memory of the agent """ - response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get in-context memory: {response.text}") return Memory(**response.json()) @@ -529,7 +531,9 @@ class RESTClient(AbstractClient): """ memory_update_dict = {section: value} - response = requests.post(f"{self.base_url}/api/agents/{agent_id}/memory", json=memory_update_dict, headers=self.headers) + response = requests.patch( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory", json=memory_update_dict, headers=self.headers + ) if response.status_code != 200: raise ValueError(f"Failed to update in-context memory: 
{response.text}") return Memory(**response.json()) @@ -545,7 +549,7 @@ class RESTClient(AbstractClient): summary (ArchivalMemorySummary): Summary of the archival memory """ - response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/archival", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/archival", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get archival memory summary: {response.text}") return ArchivalMemorySummary(**response.json()) @@ -560,7 +564,7 @@ class RESTClient(AbstractClient): Returns: summary (RecallMemorySummary): Summary of the recall memory """ - response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/recall", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/recall", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get recall memory summary: {response.text}") return RecallMemorySummary(**response.json()) @@ -575,7 +579,7 @@ class RESTClient(AbstractClient): Returns: messages (List[Message]): List of in-context messages """ - response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/messages", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/messages", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get in-context messages: {response.text}") return [Message(**message) for message in response.json()] @@ -620,7 +624,7 @@ class RESTClient(AbstractClient): params["before"] = str(before) if after: params["after"] = str(after) - response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/archival", params=params, headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{str(agent_id)}/archival", params=params, headers=self.headers) assert response.status_code == 200, f"Failed to get archival memory: {response.text}" return [Passage(**passage) for passage in response.json()] @@ -636,7 +640,9 @@ class RESTClient(AbstractClient): passages (List[Passage]): List of inserted passages """ request = CreateArchivalMemory(text=memory) - response = requests.post(f"{self.base_url}/api/agents/{agent_id}/archival", headers=self.headers, json=request.model_dump()) + response = requests.post( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/archival", headers=self.headers, json=request.model_dump() + ) if response.status_code != 200: raise ValueError(f"Failed to insert archival memory: {response.text}") return [Passage(**passage) for passage in response.json()] @@ -649,7 +655,7 @@ class RESTClient(AbstractClient): agent_id (str): ID of the agent memory_id (str): ID of the memory """ - response = requests.delete(f"{self.base_url}/api/agents/{agent_id}/archival/{memory_id}", headers=self.headers) + response = requests.delete(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/archival/{memory_id}", headers=self.headers) assert response.status_code == 200, f"Failed to delete archival memory: {response.text}" # messages (recall memory) @@ -671,7 +677,7 @@ class RESTClient(AbstractClient): """ params = {"before": before, "after": after, "limit": limit, "msg_object": True} - response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages", params=params, headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", params=params, headers=self.headers) if 
response.status_code != 200: raise ValueError(f"Failed to get messages: {response.text}") return [Message(**message) for message in response.json()] @@ -708,9 +714,11 @@ class RESTClient(AbstractClient): from memgpt.client.streaming import _sse_post request.return_message_object = False - return _sse_post(f"{self.base_url}/api/agents/{agent_id}/messages", request.model_dump(), self.headers) + return _sse_post(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", request.model_dump(), self.headers) else: - response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers) + response = requests.post( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers + ) if response.status_code != 200: raise ValueError(f"Failed to send message: {response.text}") return MemGPTResponse(**response.json()) @@ -719,7 +727,7 @@ class RESTClient(AbstractClient): def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]: params = {"label": label, "templates_only": templates_only} - response = requests.get(f"{self.base_url}/api/blocks", params=params, headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks", params=params, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to list blocks: {response.text}") @@ -732,7 +740,7 @@ class RESTClient(AbstractClient): def create_block(self, label: str, name: str, text: str) -> Block: # request = CreateBlock(label=label, name=name, value=text) - response = requests.post(f"{self.base_url}/api/blocks", json=request.model_dump(), headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create block: {response.text}") if request.label == "human": @@ -744,13 +752,13 @@ class RESTClient(AbstractClient): def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block: request = UpdateBlock(id=block_id, name=name, value=text) - response = requests.post(f"{self.base_url}/api/blocks/{block_id}", json=request.model_dump(), headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update block: {response.text}") return Block(**response.json()) def get_block(self, block_id: str) -> Block: - response = requests.get(f"{self.base_url}/api/blocks/{block_id}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", headers=self.headers) if response.status_code == 404: return None elif response.status_code != 200: @@ -759,7 +767,7 @@ class RESTClient(AbstractClient): def get_block_id(self, name: str, label: str) -> str: params = {"name": name, "label": label} - response = requests.get(f"{self.base_url}/api/blocks", params=params, headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks", params=params, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get block ID: {response.text}") blocks = [Block(**block) for block in response.json()] @@ -770,7 +778,7 @@ class RESTClient(AbstractClient): return blocks[0].id def delete_block(self, id: str) -> Block: - response = 
requests.delete(f"{self.base_url}/api/blocks/{id}", headers=self.headers) + response = requests.delete(f"{self.base_url}/{self.api_prefix}/blocks/{id}", headers=self.headers) assert response.status_code == 200, f"Failed to delete block: {response.text}" if response.status_code != 200: raise ValueError(f"Failed to delete block: {response.text}") @@ -811,7 +819,7 @@ class RESTClient(AbstractClient): human (Human): Updated human block """ request = UpdateHuman(id=human_id, name=name, value=text) - response = requests.post(f"{self.base_url}/api/blocks/{human_id}", json=request.model_dump(), headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{human_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update human: {response.text}") return Human(**response.json()) @@ -851,7 +859,7 @@ class RESTClient(AbstractClient): persona (Persona): Updated persona block """ request = UpdatePersona(id=persona_id, name=name, value=text) - response = requests.post(f"{self.base_url}/api/blocks/{persona_id}", json=request.model_dump(), headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{persona_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update persona: {response.text}") return Persona(**response.json()) @@ -934,7 +942,7 @@ class RESTClient(AbstractClient): Returns: source (Source): Source """ - response = requests.get(f"{self.base_url}/api/sources/{source_id}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get source: {response.text}") return Source(**response.json()) @@ -949,7 +957,7 @@ class RESTClient(AbstractClient): Returns: source_id (str): ID of the source """ - response = requests.get(f"{self.base_url}/api/sources/name/{source_name}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/name/{source_name}", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get source ID: {response.text}") return response.json() @@ -961,7 +969,7 @@ class RESTClient(AbstractClient): Returns: sources (List[Source]): List of sources """ - response = requests.get(f"{self.base_url}/api/sources", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/sources", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to list sources: {response.text}") return [Source(**source) for source in response.json()] @@ -973,21 +981,21 @@ class RESTClient(AbstractClient): Args: source_id (str): ID of the source """ - response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers) + response = requests.delete(f"{self.base_url}/{self.api_prefix}/sources/{str(source_id)}", headers=self.headers) assert response.status_code == 200, f"Failed to delete source: {response.text}" def get_job(self, job_id: str) -> Job: - response = requests.get(f"{self.base_url}/api/jobs/{job_id}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get job: {response.text}") return Job(**response.json()) def list_jobs(self): - response = requests.get(f"{self.base_url}/api/jobs", headers=self.headers) + response = 
requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers) return [Job(**job) for job in response.json()] def list_active_jobs(self): - response = requests.get(f"{self.base_url}/api/jobs/active", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs/active", headers=self.headers) return [Job(**job) for job in response.json()] def load_data(self, connector: DataConnector, source_name: str): @@ -1008,7 +1016,7 @@ class RESTClient(AbstractClient): files = {"file": open(filename, "rb")} # create job - response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/upload", files=files, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to upload file to source: {response.text}") @@ -1035,7 +1043,7 @@ class RESTClient(AbstractClient): source (Source): Created source """ payload = {"name": name} - response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers) response_json = response.json() return Source(**response_json) @@ -1049,7 +1057,7 @@ class RESTClient(AbstractClient): Returns: sources (List[Source]): List of sources """ - response = requests.get(f"{self.base_url}/api/agents/{agent_id}/sources", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/sources", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to list attached sources: {response.text}") return [Source(**source) for source in response.json()] @@ -1066,7 +1074,7 @@ class RESTClient(AbstractClient): source (Source): Updated source """ request = SourceUpdate(id=source_id, name=name) - response = requests.post(f"{self.base_url}/api/sources/{source_id}", json=request.model_dump(), headers=self.headers) + response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update source: {response.text}") return Source(**response.json()) @@ -1081,13 +1089,13 @@ class RESTClient(AbstractClient): source_name (str): Name of the source """ params = {"agent_id": agent_id} - response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/attach", params=params, headers=self.headers) assert response.status_code == 200, f"Failed to attach source to agent: {response.text}" def detach_source(self, source_id: str, agent_id: str): """Detach a source from an agent""" params = {"agent_id": str(agent_id)} - response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/detach", params=params, headers=self.headers) assert response.status_code == 200, f"Failed to detach source from agent: {response.text}" # server configuration commands @@ -1099,7 +1107,7 @@ class RESTClient(AbstractClient): Returns: models (List[LLMConfig]): List of LLM models """ - response = requests.get(f"{self.base_url}/api/config/llm", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers) if response.status_code 
!= 200: raise ValueError(f"Failed to list models: {response.text}") return [LLMConfig(**model) for model in response.json()] @@ -1111,7 +1119,7 @@ class RESTClient(AbstractClient): Returns: models (List[EmbeddingConfig]): List of embedding models """ - response = requests.get(f"{self.base_url}/api/config/embedding", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to list embedding models: {response.text}") return [EmbeddingConfig(**model) for model in response.json()] @@ -1128,7 +1136,7 @@ class RESTClient(AbstractClient): Returns: id (str): ID of the tool (`None` if not found) """ - response = requests.get(f"{self.base_url}/api/tools/name/{tool_name}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{tool_name}", headers=self.headers) if response.status_code == 404: return None elif response.status_code != 200: @@ -1137,7 +1145,7 @@ class RESTClient(AbstractClient): def create_tool( self, - func, + func: Callable, name: Optional[str] = None, update: Optional[bool] = True, # TODO: actually use this tags: Optional[List[str]] = None, @@ -1157,14 +1165,21 @@ class RESTClient(AbstractClient): # TODO: check tool update code # TODO: check if tool already exists + # TODO: how to load modules? # parse source code/schema source_code = parse_source_code(func) source_type = "python" + # TODO: Check if tool already exists + # if name: + # tool_id = self.get_tool_id(tool_name=name) + # if tool_id: + # raise ValueError(f"Tool with name {name} (id={tool_id}) already exists") + # call server function request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags) - response = requests.post(f"{self.base_url}/api/tools", json=request.model_dump(), headers=self.headers) + response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create tool: {response.text}") return Tool(**response.json()) @@ -1173,7 +1188,7 @@ class RESTClient(AbstractClient): self, id: str, name: Optional[str] = None, - func: Optional[callable] = None, + func: Optional[Callable] = None, tags: Optional[List[str]] = None, ) -> Tool: """ @@ -1196,7 +1211,7 @@ class RESTClient(AbstractClient): source_type = "python" request = ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name) - response = requests.post(f"{self.base_url}/api/tools/{id}", json=request.model_dump(), headers=self.headers) + response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update tool: {response.text}") return Tool(**response.json()) @@ -1235,7 +1250,7 @@ class RESTClient(AbstractClient): # raise ValueError(f"Failed to create tool: {e}, invalid input {data}") # # make REST request - # response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers) + # response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=data, headers=self.headers) # if response.status_code != 200: # raise ValueError(f"Failed to create tool: {response.text}") # return ToolModel(**response.json()) @@ -1247,7 +1262,7 @@ class RESTClient(AbstractClient): Returns: tools (List[Tool]): List of tools """ - response = requests.get(f"{self.base_url}/api/tools", headers=self.headers) 
+ response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to list tools: {response.text}") return [Tool(**tool) for tool in response.json()] @@ -1259,11 +1274,11 @@ class RESTClient(AbstractClient): Args: id (str): ID of the tool """ - response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers) + response = requests.delete(f"{self.base_url}/{self.api_prefix}/tools/{name}", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to delete tool: {response.text}") - def get_tool(self, name: str): + def get_tool(self, id: str) -> Optional[Tool]: """ Get a tool give its ID. @@ -1273,13 +1288,30 @@ class RESTClient(AbstractClient): Returns: tool (Tool): Tool """ - response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/{id}", headers=self.headers) if response.status_code == 404: return None elif response.status_code != 200: raise ValueError(f"Failed to get tool: {response.text}") return Tool(**response.json()) + def get_tool_id(self, name: str) -> Optional[str]: + """ + Get a tool ID by its name. + + Args: + id (str): ID of the tool + + Returns: + tool (Tool): Tool + """ + response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{name}", headers=self.headers) + if response.status_code == 404: + return None + elif response.status_code != 200: + raise ValueError(f"Failed to get tool: {response.text}") + return response.json() + class LocalClient(AbstractClient): """ @@ -1972,7 +2004,7 @@ class LocalClient(AbstractClient): tools = self.server.list_tools(user_id=self.user_id) return tools - def get_tool(self, id: str) -> Tool: + def get_tool(self, id: str) -> Optional[Tool]: """ Get a tool give its ID. 
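Reviewer usage sketch: the client-side effect of this refactor is that every request is now built as {base_url}/{api_prefix}/... (default prefix "v1") instead of the hard-coded {base_url}/api/..., and tools are fetched by ID after resolving a name through the new get_tool_id helper. A minimal sketch, assuming a server at http://localhost:8283 and a registered tool named "my_tool" (both hypothetical placeholders):

    from memgpt.client.client import RESTClient

    # Hypothetical connection details -- substitute a real server URL and API key.
    client = RESTClient(
        base_url="http://localhost:8283",
        token="<api-key>",
        api_prefix="v1",  # default; requests go to {base_url}/v1/... rather than {base_url}/api/...
    )

    # get_tool() now takes an ID, so resolve the tool name first with the new helper.
    tool_id = client.get_tool_id("my_tool")  # returns None if no tool has that name
    if tool_id is not None:
        tool = client.get_tool(tool_id)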
@@ -2222,7 +2254,7 @@ class LocalClient(AbstractClient): Returns: models (List[LLMConfig]): List of LLM models """ - return [self.server.server_llm_config] + return self.server.list_models() def list_embedding_models(self) -> List[EmbeddingConfig]: """ @@ -2231,7 +2263,7 @@ class LocalClient(AbstractClient): Returns: models (List[EmbeddingConfig]): List of embedding models """ - return [self.server.server_embedding_config] + return self.server.list_embedding_models() def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]: """ diff --git a/memgpt/metadata.py b/memgpt/metadata.py index ddabb584..106d35bd 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -744,7 +744,7 @@ class MetadataStore: template: Optional[bool] = None, name: Optional[str] = None, id: Optional[str] = None, - ) -> List[Block]: + ) -> Optional[List[Block]]: """List available blocks""" with self.session_maker() as session: query = session.query(BlockModel) diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py deleted file mode 100644 index bfbcfb81..00000000 --- a/memgpt/server/rest_api/agents/index.py +++ /dev/null @@ -1,106 +0,0 @@ -from functools import partial -from typing import List - -from fastapi import APIRouter, Body, Depends, HTTPException - -from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState -from memgpt.schemas.source import Source -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer - -router = APIRouter() - - -def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), password) - - @router.get("/agents", tags=["agents"], response_model=List[AgentState]) - def list_agents( - user_id: str = Depends(get_current_user_with_server), - ): - """ - List all agents associated with a given user. - - This endpoint retrieves a list of all agents and their configurations associated with the specified user ID. - """ - interface.clear() - agents_data = server.list_agents(user_id=user_id) - return agents_data - - @router.post("/agents", tags=["agents"], response_model=AgentState) - def create_agent( - request: CreateAgent = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Create a new agent with the specified configuration. - """ - interface.clear() - - agent_state = server.create_agent(request, user_id=user_id) - return agent_state - - @router.post("/agents/{agent_id}", tags=["agents"], response_model=AgentState) - def update_agent( - agent_id: str, - request: UpdateAgentState = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """Update an existing agent""" - interface.clear() - try: - # TODO: should id be moved out of UpdateAgentState? - agent_state = server.update_agent(request, user_id=user_id) - except Exception as e: - print(str(e)) - raise HTTPException(status_code=500, detail=str(e)) - - return agent_state - - @router.get("/agents/{agent_id}", tags=["agents"], response_model=AgentState) - def get_agent_state( - agent_id: str = None, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Get the state of the agent.
- """ - - interface.clear() - if not server.ms.get_agent(user_id=user_id, agent_id=agent_id): - # agent does not exist - raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.") - - return server.get_agent_state(user_id=user_id, agent_id=agent_id) - - @router.delete("/agents/{agent_id}", tags=["agents"]) - def delete_agent( - agent_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Delete an agent. - """ - # agent_id = str(agent_id) - - interface.clear() - try: - server.delete_agent(user_id=user_id, agent_id=agent_id) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") - - @router.get("/agents/{agent_id}/sources", tags=["agents"], response_model=List[Source]) - def get_agent_sources( - agent_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Get the sources associated with an agent. - """ - interface.clear() - return server.list_attached_sources(agent_id) - - return router diff --git a/memgpt/server/rest_api/agents/memory.py b/memgpt/server/rest_api/agents/memory.py deleted file mode 100644 index 13eed748..00000000 --- a/memgpt/server/rest_api/agents/memory.py +++ /dev/null @@ -1,148 +0,0 @@ -from functools import partial -from typing import Dict, List, Optional - -from fastapi import APIRouter, Body, Depends, HTTPException, Query, status -from fastapi.responses import JSONResponse - -from memgpt.schemas.memory import ( - ArchivalMemorySummary, - CreateArchivalMemory, - Memory, - RecallMemorySummary, -) -from memgpt.schemas.message import Message -from memgpt.schemas.passage import Passage -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer - -router = APIRouter() - - -def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), password) - - @router.get("/agents/{agent_id}/memory/messages", tags=["agents"], response_model=List[Message]) - def get_agent_in_context_messages( - agent_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Retrieve the messages in the context of a specific agent. - """ - interface.clear() - return server.get_in_context_messages(agent_id=agent_id) - - @router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory) - def get_agent_memory( - agent_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Retrieve the memory state of a specific agent. - - This endpoint fetches the current memory state of the agent identified by the user ID and agent ID. - """ - interface.clear() - return server.get_agent_memory(agent_id=agent_id) - - @router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory) - def update_agent_memory( - agent_id: str, - request: Dict = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Update the core memory of a specific agent. - - This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID. 
- """ - interface.clear() - memory = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=request) - return memory - - @router.get("/agents/{agent_id}/memory/recall", tags=["agents"], response_model=RecallMemorySummary) - def get_agent_recall_memory_summary( - agent_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Retrieve the summary of the recall memory of a specific agent. - """ - interface.clear() - return server.get_recall_memory_summary(agent_id=agent_id) - - @router.get("/agents/{agent_id}/memory/archival", tags=["agents"], response_model=ArchivalMemorySummary) - def get_agent_archival_memory_summary( - agent_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Retrieve the summary of the archival memory of a specific agent. - """ - interface.clear() - return server.get_archival_memory_summary(agent_id=agent_id) - - # @router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=List[Passage]) - # def get_agent_archival_memory_all( - # agent_id: str, - # user_id: str = Depends(get_current_user_with_server), - # ): - # """ - # Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once). - # """ - # interface.clear() - # return server.get_all_archival_memories(user_id=user_id, agent_id=agent_id) - - @router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage]) - def get_agent_archival_memory( - agent_id: str, - after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."), - before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."), - limit: Optional[int] = Query(None, description="How many results to include in the response."), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Retrieve the memories in an agent's archival memory store (paginated query). - """ - interface.clear() - return server.get_agent_archival_cursor( - user_id=user_id, - agent_id=agent_id, - after=after, - before=before, - limit=limit, - ) - - @router.post("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage]) - def insert_agent_archival_memory( - agent_id: str, - request: CreateArchivalMemory = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Insert a memory into an agent's archival memory store. - """ - interface.clear() - return server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=request.text) - - @router.delete("/agents/{agent_id}/archival/{memory_id}", tags=["agents"]) - def delete_agent_archival_memory( - agent_id: str, - memory_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Delete a memory from an agent's archival memory store. 
- """ - # TODO: should probably return a `Passage` - interface.clear() - try: - server.delete_archival_memory(user_id=user_id, agent_id=agent_id, memory_id=memory_id) - return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") - - return router diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py deleted file mode 100644 index 89663183..00000000 --- a/memgpt/server/rest_api/agents/message.py +++ /dev/null @@ -1,197 +0,0 @@ -import asyncio -from datetime import datetime -from functools import partial -from typing import List, Optional, Union - -from fastapi import APIRouter, Body, Depends, HTTPException, Query -from fastapi.responses import StreamingResponse - -from memgpt.schemas.enums import MessageRole, MessageStreamStatus -from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage -from memgpt.schemas.memgpt_request import MemGPTRequest -from memgpt.schemas.memgpt_response import MemGPTResponse -from memgpt.schemas.message import Message, UpdateMessage -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface -from memgpt.server.rest_api.utils import sse_async_generator -from memgpt.server.server import SyncServer -from memgpt.utils import deduplicate - -router = APIRouter() - - -# TODO: cpacker should check this file -# TODO: move this into server.py? -async def send_message_to_agent( - server: SyncServer, - agent_id: str, - user_id: str, - role: MessageRole, - message: str, - stream_steps: bool, - stream_tokens: bool, - return_message_object: bool, # Should be True for Python Client, False for REST API - chat_completion_mode: Optional[bool] = False, - timestamp: Optional[datetime] = None, - # related to whether or not we return `MemGPTMessage`s or `Message`s -) -> Union[StreamingResponse, MemGPTResponse]: - """Split off into a separate function so that it can be imported in the /chat/completion proxy.""" - # TODO: @charles is this the correct way to handle? 
- include_final_message = True - - # determine role - if role == MessageRole.user: - message_func = server.user_message - elif role == MessageRole.system: - message_func = server.system_message - else: - raise HTTPException(status_code=500, detail=f"Bad role {role}") - - if not stream_steps and stream_tokens: - raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'") - - # For streaming response - try: - - # TODO: move this logic into server.py - - # Get the generator object off of the agent's streaming interface - # This will be attached to the POST SSE request used under-the-hood - memgpt_agent = server._get_or_load_agent(agent_id=agent_id) - streaming_interface = memgpt_agent.interface - if not isinstance(streaming_interface, StreamingServerInterface): - raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") - - # Enable token-streaming within the request if desired - streaming_interface.streaming_mode = stream_tokens - # "chatcompletion mode" does some remapping and ignores inner thoughts - streaming_interface.streaming_chat_completion_mode = chat_completion_mode - - # streaming_interface.allow_assistant_message = stream - # streaming_interface.function_call_legacy_mode = stream - - # Offload the synchronous message_func to a separate thread - streaming_interface.stream_start() - task = asyncio.create_task( - asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp) - ) - - if stream_steps: - if return_message_object: - # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format - raise NotImplementedError - - # return a stream - return StreamingResponse( - sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message), - media_type="text/event-stream", - ) - - else: - # buffer the stream, then return the list - generated_stream = [] - async for message in streaming_interface.get_generator(): - assert ( - isinstance(message, MemGPTMessage) - or isinstance(message, LegacyMemGPTMessage) - or isinstance(message, MessageStreamStatus) - ), type(message) - generated_stream.append(message) - if message == MessageStreamStatus.done: - break - - # Get rid of the stream status messages - filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)] - usage = await task - - # By default the stream will be messages of type MemGPTMessage or LegacyMemGPTMessage - # If we want to convert these to Message, we can use the attached IDs - # NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call) - # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead - if return_message_object: - message_ids = [m.id for m in filtered_stream] - message_ids = deduplicate(message_ids) - message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids] - return MemGPTResponse(messages=message_objs, usage=usage) - else: - return MemGPTResponse(messages=filtered_stream, usage=usage) - - except HTTPException: - raise - except Exception as e: - print(e) - import traceback - - traceback.print_exc() - raise HTTPException(status_code=500, detail=f"{e}") - - -def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), password) - - @router.get("/agents/{agent_id}/messages", tags=["agents"],
response_model=List[Message]) - def get_agent_messages( - agent_id: str, - before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), - limit: int = Query(10, description="Maximum number of messages to retrieve."), - msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Retrieve message history for an agent. - """ - interface.clear() - return server.get_agent_recall_cursor( - user_id=user_id, - agent_id=agent_id, - before=before, - limit=limit, - reverse=True, - return_message_object=msg_object, - ) - - @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse) - async def send_message( - # background_tasks: BackgroundTasks, - agent_id: str, - request: MemGPTRequest = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Process a user message and return the agent's response. - - This endpoint accepts a message from a user and processes it through the agent. - It can optionally stream the response if 'stream' is set to True. - """ - # TODO: should this receive multiple messages? @cpacker - # TODO: revise to `MemGPTRequest` - # TODO: support sending multiple messages - assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}" - message = request.messages[0] - - # TODO: what to do with message.name? - return await send_message_to_agent( - server=server, - agent_id=agent_id, - user_id=user_id, - role=message.role, - message=message.text, - stream_steps=request.stream_steps, - stream_tokens=request.stream_tokens, - return_message_object=request.return_message_object, - ) - - @router.patch("/agents/{agent_id}/messages/{message_id}", tags=["agents"], response_model=Message) - async def update_message( - agent_id: str, - message_id: str, - request: UpdateMessage = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Update the details of a message associated with an agent. - """ - assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}" - return server.update_agent_message(agent_id=agent_id, request=request) - - return router diff --git a/memgpt/server/rest_api/block/index.py b/memgpt/server/rest_api/block/index.py deleted file mode 100644 index d306be99..00000000 --- a/memgpt/server/rest_api/block/index.py +++ /dev/null @@ -1,73 +0,0 @@ -from functools import partial -from typing import List, Optional - -from fastapi import APIRouter, Body, Depends, HTTPException, Query - -from memgpt.schemas.block import Block, CreateBlock -from memgpt.schemas.block import Human as HumanModel # TODO: modify -from memgpt.schemas.block import UpdateBlock -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer - -router = APIRouter() - - -def setup_block_index_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), password) - - @router.get("/blocks", tags=["block"], response_model=List[Block]) - async def list_blocks( - user_id: str = Depends(get_current_user_with_server), - # query parameters - label: Optional[str] = Query(None, description="Labels to include (e.g.
human, persona)"), - templates_only: bool = Query(True, description="Whether to include only templates"), - name: Optional[str] = Query(None, description="Name of the block"), - ): - # Clear the interface - interface.clear() - blocks = server.get_blocks(user_id=user_id, label=label, template=templates_only, name=name) - if blocks is None: - return [] - return blocks - - @router.post("/blocks", tags=["block"], response_model=Block) - async def create_block( - request: CreateBlock = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - interface.clear() - request.user_id = user_id # TODO: remove? - return server.create_block(user_id=user_id, request=request) - - @router.post("/blocks/{block_id}", tags=["block"], response_model=Block) - async def update_block( - block_id: str, - request: UpdateBlock = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - interface.clear() - # TODO: should this be in the param or the POST data? - assert block_id == request.id - return server.update_block(request) - - @router.delete("/blocks/{block_id}", tags=["block"], response_model=Block) - async def delete_block( - block_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - interface.clear() - return server.delete_block(block_id=block_id) - - @router.get("/blocks/{block_id}", tags=["block"], response_model=Block) - async def get_block( - block_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - interface.clear() - block = server.get_block(block_id=block_id) - if block is None: - raise HTTPException(status_code=404, detail="Block not found") - return block - - return router diff --git a/memgpt/server/rest_api/config/index.py b/memgpt/server/rest_api/config/index.py deleted file mode 100644 index 79b2dc22..00000000 --- a/memgpt/server/rest_api/config/index.py +++ /dev/null @@ -1,40 +0,0 @@ -from functools import partial -from typing import List - -from fastapi import APIRouter, Depends -from pydantic import BaseModel, Field - -from memgpt.schemas.embedding_config import EmbeddingConfig -from memgpt.schemas.llm_config import LLMConfig -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer - -router = APIRouter() - - -class ConfigResponse(BaseModel): - config: dict = Field(..., description="The server configuration object.") - defaults: dict = Field(..., description="The defaults for the configuration.") - - -def setup_config_index_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), password) - - @router.get("/config/llm", tags=["config"], response_model=List[LLMConfig]) - def get_llm_configs(user_id: str = Depends(get_current_user_with_server)): - """ - Retrieve the base configuration for the server. - """ - interface.clear() - return [server.server_llm_config] - - @router.get("/config/embedding", tags=["config"], response_model=List[EmbeddingConfig]) - def get_embedding_configs(user_id: str = Depends(get_current_user_with_server)): - """ - Retrieve the base configuration for the server. 
- """ - interface.clear() - return [server.server_embedding_config] - - return router diff --git a/memgpt/server/rest_api/jobs/index.py b/memgpt/server/rest_api/jobs/index.py deleted file mode 100644 index 0a2cb083..00000000 --- a/memgpt/server/rest_api/jobs/index.py +++ /dev/null @@ -1,41 +0,0 @@ -from functools import partial -from typing import List - -from fastapi import APIRouter, Depends - -from memgpt.schemas.job import Job -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer - -router = APIRouter() - - -def setup_jobs_index_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), password) - - @router.get("/jobs", tags=["jobs"], response_model=List[Job]) - async def list_jobs( - user_id: str = Depends(get_current_user_with_server), - ): - interface.clear() - - # TODO: add filtering by status - return server.list_jobs(user_id=user_id) - - @router.get("/jobs/active", tags=["jobs"], response_model=List[Job]) - async def list_active_jobs( - user_id: str = Depends(get_current_user_with_server), - ): - interface.clear() - return server.list_active_jobs(user_id=user_id) - - @router.get("/jobs/{job_id}", tags=["jobs"], response_model=Job) - async def get_job( - job_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - interface.clear() - return server.get_job(job_id=job_id) - - return router diff --git a/memgpt/server/rest_api/models/index.py b/memgpt/server/rest_api/models/index.py deleted file mode 100644 index 72755545..00000000 --- a/memgpt/server/rest_api/models/index.py +++ /dev/null @@ -1,38 +0,0 @@ -from functools import partial -from typing import List - -from fastapi import APIRouter -from pydantic import BaseModel, Field - -from memgpt.schemas.llm_config import LLMConfig -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer - -router = APIRouter() - - -class ListModelsResponse(BaseModel): - models: List[LLMConfig] = Field(..., description="List of model configurations.") - - -def setup_models_index_router(server: SyncServer, interface: QueuingInterface, password: str): - partial(partial(get_current_user, server), password) - - @router.get("/models", tags=["models"], response_model=ListModelsResponse) - async def list_models(): - # Clear the interface - interface.clear() - - # currently, the server only supports one model, however this may change in the future - llm_config = LLMConfig( - model=server.server_llm_config.model, - model_endpoint=server.server_llm_config.model_endpoint, - model_endpoint_type=server.server_llm_config.model_endpoint_type, - model_wrapper=server.server_llm_config.model_wrapper, - context_window=server.server_llm_config.context_window, - ) - - return ListModelsResponse(models=[llm_config]) - - return router diff --git a/memgpt/server/rest_api/openai_assistants/assistants.py b/memgpt/server/rest_api/openai_assistants/assistants.py deleted file mode 100644 index bd92ea1a..00000000 --- a/memgpt/server/rest_api/openai_assistants/assistants.py +++ /dev/null @@ -1,488 +0,0 @@ -import uuid -from typing import List, Optional - -from fastapi import APIRouter, Body, HTTPException, Path, Query -from pydantic import BaseModel, Field - -from memgpt.constants import DEFAULT_PRESET -from memgpt.schemas.message import Message -from 
memgpt.schemas.openai.openai import ( - AssistantFile, - MessageFile, - MessageRoleType, - OpenAIAssistant, - OpenAIMessage, - OpenAIRun, - OpenAIRunStep, - OpenAIThread, - Text, - ToolCall, - ToolCallOutput, -) -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer -from memgpt.utils import get_utc_time - -router = APIRouter() - - -class CreateAssistantRequest(BaseModel): - model: str = Field(..., description="The model to use for the assistant.") - name: str = Field(..., description="The name of the assistant.") - description: str = Field(None, description="The description of the assistant.") - instructions: str = Field(..., description="The instructions for the assistant.") - tools: List[str] = Field(None, description="The tools used by the assistant.") - file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.") - metadata: dict = Field(None, description="Metadata associated with the assistant.") - - # memgpt-only (not openai) - embedding_model: str = Field(None, description="The model to use for the assistant.") - - ## TODO: remove - # user_id: str = Field(..., description="The unique identifier of the user.") - - -class CreateThreadRequest(BaseModel): - messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.") - metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.") - - # memgpt-only - assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. MemGPT preset)") - - -class ModifyThreadRequest(BaseModel): - metadata: dict = Field(None, description="Metadata associated with the thread.") - - -class ModifyMessageRequest(BaseModel): - metadata: dict = Field(None, description="Metadata associated with the message.") - - -class ModifyRunRequest(BaseModel): - metadata: dict = Field(None, description="Metadata associated with the run.") - - -class CreateMessageRequest(BaseModel): - role: str = Field(..., description="Role of the message sender (either 'user' or 'system')") - content: str = Field(..., description="The message content to be processed by the agent.") - file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.") - metadata: Optional[dict] = Field(None, description="Metadata associated with the message.") - - -class UserMessageRequest(BaseModel): - user_id: str = Field(..., description="The unique identifier of the user.") - agent_id: str = Field(..., description="The unique identifier of the agent.") - message: str = Field(..., description="The message content to be processed by the agent.") - stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. 
Set to True for streaming.") - role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')") - - -class UserMessageResponse(BaseModel): - messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.") - - -class GetAgentMessagesRequest(BaseModel): - user_id: str = Field(..., description="The unique identifier of the user.") - agent_id: str = Field(..., description="The unique identifier of the agent.") - start: int = Field(..., description="Message index to start on (reverse chronological).") - count: int = Field(..., description="How many messages to retrieve.") - - -class ListMessagesResponse(BaseModel): - messages: List[OpenAIMessage] = Field(..., description="List of message objects.") - - -class CreateAssistantFileRequest(BaseModel): - file_id: str = Field(..., description="The unique identifier of the file.") - - -class CreateRunRequest(BaseModel): - assistant_id: str = Field(..., description="The unique identifier of the assistant.") - model: Optional[str] = Field(None, description="The model used by the run.") - instructions: str = Field(..., description="The instructions for the run.") - additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.") - tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).") - metadata: Optional[dict] = Field(None, description="Metadata associated with the run.") - - -class CreateThreadRunRequest(BaseModel): - assistant_id: str = Field(..., description="The unique identifier of the assistant.") - thread: OpenAIThread = Field(..., description="The thread to run.") - model: str = Field(..., description="The model used by the run.") - instructions: str = Field(..., description="The instructions for the run.") - tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).") - metadata: Optional[dict] = Field(None, description="Metadata associated with the run.") - - -class DeleteAssistantResponse(BaseModel): - id: str = Field(..., description="The unique identifier of the agent.") - object: str = "assistant.deleted" - deleted: bool = Field(..., description="Whether the agent was deleted.") - - -class DeleteAssistantFileResponse(BaseModel): - id: str = Field(..., description="The unique identifier of the file.") - object: str = "assistant.file.deleted" - deleted: bool = Field(..., description="Whether the file was deleted.") - - -class DeleteThreadResponse(BaseModel): - id: str = Field(..., description="The unique identifier of the agent.") - object: str = "thread.deleted" - deleted: bool = Field(..., description="Whether the agent was deleted.") - - -class SubmitToolOutputsToRunRequest(BaseModel): - tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.") - - -# TODO: implement mechanism for creating/authenticating users associated with a bearer token -def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface): - # create assistant (MemGPT agent) - @router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant) - def create_assistant(request: CreateAssistantRequest = Body(...)): - # TODO: create preset - return OpenAIAssistant( - id=DEFAULT_PRESET, - name="default_preset", - description=request.description, - created_at=int(get_utc_time().timestamp()), - model=request.model, - 
instructions=request.instructions, - tools=request.tools, - file_ids=request.file_ids, - metadata=request.metadata, - ) - - @router.post("/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile) - def create_assistant_file( - assistant_id: str = Path(..., description="The unique identifier of the assistant."), - request: CreateAssistantFileRequest = Body(...), - ): - # TODO: add file to assistant - return AssistantFile( - id=request.file_id, - created_at=int(get_utc_time().timestamp()), - assistant_id=assistant_id, - ) - - @router.get("/assistants", tags=["assistants"], response_model=List[OpenAIAssistant]) - def list_assistants( - limit: int = Query(1000, description="How many assistants to retrieve."), - order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."), - after: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." - ), - before: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." - ), - ): - # TODO: implement list assistants (i.e. list available MemGPT presets) - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.get("/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile]) - def list_assistant_files( - assistant_id: str = Path(..., description="The unique identifier of the assistant."), - limit: int = Query(1000, description="How many files to retrieve."), - order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."), - after: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." - ), - before: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." 
- ), - ): - # TODO: list attached data sources to preset - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.get("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant) - def retrieve_assistant( - assistant_id: str = Path(..., description="The unique identifier of the assistant."), - ): - # TODO: get and return preset - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.get("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile) - def retrieve_assistant_file( - assistant_id: str = Path(..., description="The unique identifier of the assistant."), - file_id: str = Path(..., description="The unique identifier of the file."), - ): - # TODO: return data source attached to preset - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.post("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant) - def modify_assistant( - assistant_id: str = Path(..., description="The unique identifier of the assistant."), - request: CreateAssistantRequest = Body(...), - ): - # TODO: modify preset - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.delete("/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse) - def delete_assistant( - assistant_id: str = Path(..., description="The unique identifier of the assistant."), - ): - # TODO: delete preset - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.delete("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse) - def delete_assistant_file( - assistant_id: str = Path(..., description="The unique identifier of the assistant."), - file_id: str = Path(..., description="The unique identifier of the file."), - ): - # TODO: delete source on preset - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.post("/threads", tags=["threads"], response_model=OpenAIThread) - def create_thread(request: CreateThreadRequest = Body(...)): - # TODO: use requests.description and requests.metadata fields - # TODO: handle requests.file_ids and requests.tools - # TODO: eventually allow request to override embedding/llm model - - print("Create thread/agent", request) - # create a memgpt agent - agent_state = server.create_agent( - user_id=user_id, - ) - # TODO: insert messages into recall memory - return OpenAIThread( - id=str(agent_state.id), - created_at=int(agent_state.created_at.timestamp()), - ) - - @router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread) - def retrieve_thread( - thread_id: str = Path(..., description="The unique identifier of the thread."), - ): - agent = server.get_agent(uuid.UUID(thread_id)) - return OpenAIThread( - id=str(agent.id), - created_at=int(agent.created_at.timestamp()), - ) - - @router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread) - def modify_thread( - thread_id: str = Path(..., description="The unique identifier of the thread."), - request: ModifyThreadRequest = Body(...), - ): - # TODO: add agent metadata so this can be modified - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.delete("/threads/{thread_id}", tags=["threads"], response_model=DeleteThreadResponse) - def delete_thread( - thread_id: str = Path(..., 
description="The unique identifier of the thread."), - ): - # TODO: delete agent - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.post("/threads/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage) - def create_message( - thread_id: str = Path(..., description="The unique identifier of the thread."), - request: CreateMessageRequest = Body(...), - ): - agent_id = uuid.UUID(thread_id) - # create message object - message = Message( - user_id=user_id, - agent_id=agent_id, - role=request.role, - text=request.content, - ) - agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id) - # add message to agent - agent._append_to_messages([message]) - - openai_message = OpenAIMessage( - id=str(message.id), - created_at=int(message.created_at.timestamp()), - content=[Text(text=message.text)], - role=message.role, - thread_id=str(message.agent_id), - assistant_id=DEFAULT_PRESET, # TODO: update this - # file_ids=message.file_ids, - # metadata=message.metadata, - ) - return openai_message - - @router.get("/threads/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse) - def list_messages( - thread_id: str = Path(..., description="The unique identifier of the thread."), - limit: int = Query(1000, description="How many messages to retrieve."), - order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."), - after: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." - ), - before: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." - ), - ): - after_uuid = uuid.UUID(after) if before else None - before_uuid = uuid.UUID(before) if before else None - agent_id = uuid.UUID(thread_id) - reverse = True if (order == "desc") else False - cursor, json_messages = server.get_agent_recall_cursor( - user_id=user_id, - agent_id=agent_id, - limit=limit, - after=after_uuid, - before=before_uuid, - order_by="created_at", - reverse=reverse, - ) - print(json_messages[0]["text"]) - # convert to openai style messages - openai_messages = [ - OpenAIMessage( - id=str(message["id"]), - created_at=int(message["created_at"].timestamp()), - content=[Text(text=message["text"])], - role=message["role"], - thread_id=str(message["agent_id"]), - assistant_id=DEFAULT_PRESET, # TODO: update this - # file_ids=message.file_ids, - # metadata=message.metadata, - ) - for message in json_messages - ] - print("MESSAGES", openai_messages) - # TODO: cast back to message objects - return ListMessagesResponse(messages=openai_messages) - - router.get("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage) - - def retrieve_message( - thread_id: str = Path(..., description="The unique identifier of the thread."), - message_id: str = Path(..., description="The unique identifier of the message."), - ): - message_id = uuid.UUID(message_id) - agent_id = uuid.UUID(thread_id) - message = server.get_agent_message(agent_id, message_id) - return OpenAIMessage( - id=str(message.id), - created_at=int(message.created_at.timestamp()), - content=[Text(text=message.text)], - role=message.role, - thread_id=str(message.agent_id), - assistant_id=DEFAULT_PRESET, # TODO: update this - # file_ids=message.file_ids, - # metadata=message.metadata, - ) - - @router.get("/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], 
response_model=MessageFile) - def retrieve_message_file( - thread_id: str = Path(..., description="The unique identifier of the thread."), - message_id: str = Path(..., description="The unique identifier of the message."), - file_id: str = Path(..., description="The unique identifier of the file."), - ): - # TODO: implement? - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.post("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage) - def modify_message( - thread_id: str = Path(..., description="The unique identifier of the thread."), - message_id: str = Path(..., description="The unique identifier of the message."), - request: ModifyMessageRequest = Body(...), - ): - # TODO: add metada field to message so this can be modified - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.post("/threads/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun) - def create_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - request: CreateRunRequest = Body(...), - ): - # TODO: add request.instructions as a message? - agent_id = uuid.UUID(thread_id) - # TODO: override preset of agent with request.assistant_id - agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id) - agent.step(user_message=None) # already has messages added - run_id = str(uuid.uuid4()) - create_time = int(get_utc_time().timestamp()) - return OpenAIRun( - id=run_id, - created_at=create_time, - thread_id=str(agent_id), - assistant_id=DEFAULT_PRESET, # TODO: update this - status="completed", # TODO: eventaully allow offline execution - expires_at=create_time, - model=agent.agent_state.llm_config.model, - instructions=request.instructions, - ) - - @router.post("/threads/runs", tags=["runs"], response_model=OpenAIRun) - def create_thread_and_run( - request: CreateThreadRunRequest = Body(...), - ): - # TODO: add a bunch of messages and execute - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.get("/threads/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun]) - def list_runs( - thread_id: str = Path(..., description="The unique identifier of the thread."), - limit: int = Query(1000, description="How many runs to retrieve."), - order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."), - after: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." - ), - before: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." - ), - ): - # TODO: store run information in a DB so it can be returned here - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.get("/threads/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep]) - def list_run_steps( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - limit: int = Query(1000, description="How many run steps to retrieve."), - order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."), - after: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." 
- ), - before: str = Query( - None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list." - ), - ): - # TODO: store run information in a DB so it can be returned here - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.get("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun) - def retrieve_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - ): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.get("/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep) - def retrieve_run_step( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - step_id: str = Path(..., description="The unique identifier of the run step."), - ): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.post("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun) - def modify_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - request: ModifyRunRequest = Body(...), - ): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.post("/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun) - def submit_tool_outputs_to_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - request: SubmitToolOutputsToRunRequest = Body(...), - ): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - @router.post("/threads/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun) - def cancel_run( - thread_id: str = Path(..., description="The unique identifier of the thread."), - run_id: str = Path(..., description="The unique identifier of the run."), - ): - raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") - - return router diff --git a/memgpt/server/rest_api/openai_chat_completions/chat_completions.py b/memgpt/server/rest_api/openai_chat_completions/chat_completions.py deleted file mode 100644 index fe3f71e3..00000000 --- a/memgpt/server/rest_api/openai_chat_completions/chat_completions.py +++ /dev/null @@ -1,129 +0,0 @@ -import json -import uuid -from functools import partial - -from fastapi import APIRouter, Body, Depends, HTTPException - -# from memgpt.schemas.message import Message -from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest -from memgpt.schemas.openai.chat_completion_response import ( - ChatCompletionResponse, - Choice, - Message, - UsageStatistics, -) -from memgpt.server.rest_api.agents.message import send_message_to_agent -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer -from memgpt.utils import get_utc_time - -router = APIRouter() - - -def setup_openai_chat_completions_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), 
password) - - @router.post("/chat/completions", tags=["chat_completions"], response_model=ChatCompletionResponse) - async def create_chat_completion( - request: ChatCompletionRequest = Body(...), - user_id: uuid.UUID = Depends(get_current_user_with_server), - ): - """Send a message to a MemGPT agent via a /chat/completions request - - The bearer token will be used to identify the user. - The 'user' field in the request should be set to the agent ID. - """ - agent_id = request.user - if agent_id is None: - raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field") - try: - agent_id = uuid.UUID(agent_id) - except: - raise HTTPException(status_code=400, detail="agent_id (in the 'user' field) must be a valid UUID") - - messages = request.messages - if messages is None: - raise HTTPException(status_code=400, detail="'messages' field must not be empty") - if len(messages) > 1: - raise HTTPException(status_code=400, detail="'messages' field must be a list of length 1") - if messages[0].role != "user": - raise HTTPException(status_code=400, detail="'messages[0].role' must be a 'user'") - - input_message = request.messages[0] - if request.stream: - print("Starting streaming OpenAI proxy response") - - return await send_message_to_agent( - server=server, - agent_id=agent_id, - user_id=user_id, - role=input_message.role, - message=str(input_message.content), - stream_legacy=False, - # Turn streaming ON - stream_steps=True, - stream_tokens=True, - # Turn on ChatCompletion mode (eg remaps send_message to content) - chat_completion_mode=True, - ) - - else: - print("Starting non-streaming OpenAI proxy response") - - response_messages = await send_message_to_agent( - server=server, - agent_id=agent_id, - user_id=user_id, - role=input_message.role, - message=str(input_message.content), - stream_legacy=False, - # Turn streaming OFF - stream_steps=False, - stream_tokens=False, - ) - # print(response_messages) - - # Concatenate all send_message outputs together - id = "" - visible_message_str = "" - created_at = None - for memgpt_msg in response_messages.messages: - if "function_call" in memgpt_msg: - memgpt_function_call = memgpt_msg["function_call"] - if "name" in memgpt_function_call and memgpt_function_call["name"] == "send_message": - try: - memgpt_function_call_args = json.loads(memgpt_function_call["arguments"]) - visible_message_str += memgpt_function_call_args["message"] - id = memgpt_function_call["id"] - created_at = memgpt_msg["date"] - except: - print(f"Failed to parse MemGPT message: {str(memgpt_function_call)}") - else: - print(f"Skipping function_call: {str(memgpt_function_call)}") - else: - print(f"Skipping message: {str(memgpt_msg)}") - - response = ChatCompletionResponse( - id=id, - created=created_at if created_at else get_utc_time(), - choices=[ - Choice( - finish_reason="stop", - index=0, - message=Message( - role="assistant", - content=visible_message_str, - ), - ) - ], - # TODO add real usage - usage=UsageStatistics( - completion_tokens=0, - prompt_tokens=0, - total_tokens=0, - ), - ) - return response - - return router diff --git a/memgpt/server/rest_api/personas/__init__.py b/memgpt/server/rest_api/personas/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/memgpt/server/rest_api/personas/index.py b/memgpt/server/rest_api/personas/index.py deleted file mode 100644 index abf9374d..00000000 --- a/memgpt/server/rest_api/personas/index.py +++ /dev/null @@ -1,70 +0,0 @@ -import uuid -from functools import partial -from typing import List 
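The deleted routers above and below all bind their auth dependency with the same nested-partial idiom (get_current_user_with_server = partial(partial(get_current_user, server), password)): two functools.partial calls curry the server handle and the password into get_current_user, leaving a callable that FastAPI can invoke with only the request-scoped token. A minimal self-contained sketch of that currying pattern (the stub function and values are illustrative, not MemGPT's real auth logic):

from functools import partial

def get_current_user(server, password, token: str) -> str:
    # Stub: the real dependency resolves the bearer token to a user via the server.
    return f"user-for-{token}"

server, password = object(), "secret"
get_current_user_with_server = partial(partial(get_current_user, server), password)
assert get_current_user_with_server("tok-123") == "user-for-tok-123"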
- -from fastapi import APIRouter, Body, Depends, HTTPException -from pydantic import BaseModel, Field - -from memgpt.schemas.block import Persona as PersonaModel # TODO: modify -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer - -router = APIRouter() - - -class ListPersonasResponse(BaseModel): - personas: List[PersonaModel] = Field(..., description="List of persona configurations.") - - -class CreatePersonaRequest(BaseModel): - text: str = Field(..., description="The persona text.") - name: str = Field(..., description="The name of the persona.") - - -def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), password) - - @router.get("/personas", tags=["personas"], response_model=ListPersonasResponse) - async def list_personas( - user_id: uuid.UUID = Depends(get_current_user_with_server), - ): - # Clear the interface - interface.clear() - - personas = server.ms.list_personas(user_id=user_id) - return ListPersonasResponse(personas=personas) - - @router.post("/personas", tags=["personas"], response_model=PersonaModel) - async def create_persona( - request: CreatePersonaRequest = Body(...), - user_id: uuid.UUID = Depends(get_current_user_with_server), - ): - # TODO: disallow duplicate names for personas - interface.clear() - new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id) - persona_id = new_persona.id - server.ms.create_persona(new_persona) - return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id) - - @router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel) - async def delete_persona( - persona_name: str, - user_id: uuid.UUID = Depends(get_current_user_with_server), - ): - interface.clear() - persona = server.ms.delete_persona(persona_name, user_id=user_id) - return persona - - @router.get("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel) - async def get_persona( - persona_name: str, - user_id: uuid.UUID = Depends(get_current_user_with_server), - ): - interface.clear() - persona = server.ms.get_persona(persona_name, user_id=user_id) - if persona is None: - raise HTTPException(status_code=404, detail="Persona not found") - return persona - - return router diff --git a/memgpt/server/rest_api/agents/__init__.py b/memgpt/server/rest_api/routers/__init__.py similarity index 100% rename from memgpt/server/rest_api/agents/__init__.py rename to memgpt/server/rest_api/routers/__init__.py diff --git a/memgpt/server/rest_api/block/__init__.py b/memgpt/server/rest_api/routers/openai/__init__.py similarity index 100% rename from memgpt/server/rest_api/block/__init__.py rename to memgpt/server/rest_api/routers/openai/__init__.py diff --git a/memgpt/server/rest_api/config/__init__.py b/memgpt/server/rest_api/routers/openai/assistants/__init__.py similarity index 100% rename from memgpt/server/rest_api/config/__init__.py rename to memgpt/server/rest_api/routers/openai/assistants/__init__.py diff --git a/memgpt/server/rest_api/routers/openai/assistants/assistants.py b/memgpt/server/rest_api/routers/openai/assistants/assistants.py new file mode 100644 index 00000000..111b38a8 --- /dev/null +++ b/memgpt/server/rest_api/routers/openai/assistants/assistants.py @@ -0,0 +1,115 @@ +from typing import List + +from fastapi import APIRouter, Body, 
HTTPException, Path, Query + +from memgpt.constants import DEFAULT_PRESET +from memgpt.schemas.openai.openai import AssistantFile, OpenAIAssistant +from memgpt.server.rest_api.routers.openai.assistants.schemas import ( + CreateAssistantFileRequest, + CreateAssistantRequest, + DeleteAssistantFileResponse, + DeleteAssistantResponse, +) +from memgpt.utils import get_utc_time + + +# TODO: implement mechanism for creating/authenticating users associated with a bearer token +router = APIRouter(prefix="/v1/assistants", tags=["assistants"]) + + +# create assistant (MemGPT agent) +@router.post("/", response_model=OpenAIAssistant) +def create_assistant(request: CreateAssistantRequest = Body(...)): + # TODO: create preset + return OpenAIAssistant( + id=DEFAULT_PRESET, + name="default_preset", + description=request.description, + created_at=int(get_utc_time().timestamp()), + model=request.model, + instructions=request.instructions, + tools=request.tools, + file_ids=request.file_ids, + metadata=request.metadata, + ) + + +@router.post("/{assistant_id}/files", response_model=AssistantFile) +def create_assistant_file( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + request: CreateAssistantFileRequest = Body(...), +): + # TODO: add file to assistant + return AssistantFile( + id=request.file_id, + created_at=int(get_utc_time().timestamp()), + assistant_id=assistant_id, + ) + + +@router.get("/", response_model=List[OpenAIAssistant]) +def list_assistants( + limit: int = Query(1000, description="How many assistants to retrieve."), + order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."), +): + # TODO: implement list assistants (i.e. list available MemGPT presets) + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.get("/{assistant_id}/files", response_model=List[AssistantFile]) +def list_assistant_files( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + limit: int = Query(1000, description="How many files to retrieve."), + order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. 
`before` is an object ID that defines your place in the list."), +): + # TODO: list attached data sources to preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.get("/{assistant_id}", response_model=OpenAIAssistant) +def retrieve_assistant( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), +): + # TODO: get and return preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.get("/{assistant_id}/files/{file_id}", response_model=AssistantFile) +def retrieve_assistant_file( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + file_id: str = Path(..., description="The unique identifier of the file."), +): + # TODO: return data source attached to preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.post("/{assistant_id}", response_model=OpenAIAssistant) +def modify_assistant( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + request: CreateAssistantRequest = Body(...), +): + # TODO: modify preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.delete("/{assistant_id}", response_model=DeleteAssistantResponse) +def delete_assistant( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), +): + # TODO: delete preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.delete("/{assistant_id}/files/{file_id}", response_model=DeleteAssistantFileResponse) +def delete_assistant_file( + assistant_id: str = Path(..., description="The unique identifier of the assistant."), + file_id: str = Path(..., description="The unique identifier of the file."), +): + # TODO: delete source on preset + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") diff --git a/memgpt/server/rest_api/routers/openai/assistants/schemas.py b/memgpt/server/rest_api/routers/openai/assistants/schemas.py new file mode 100644 index 00000000..9ab1860b --- /dev/null +++ b/memgpt/server/rest_api/routers/openai/assistants/schemas.py @@ -0,0 +1,121 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field + +from memgpt.schemas.openai.openai import ( + MessageRoleType, + OpenAIMessage, + OpenAIThread, + ToolCall, + ToolCallOutput, +) + + +class CreateAssistantRequest(BaseModel): + model: str = Field(..., description="The model to use for the assistant.") + name: str = Field(..., description="The name of the assistant.") + description: str = Field(None, description="The description of the assistant.") + instructions: str = Field(..., description="The instructions for the assistant.") + tools: List[str] = Field(None, description="The tools used by the assistant.") + file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.") + metadata: dict = Field(None, description="Metadata associated with the assistant.") + + # memgpt-only (not openai) + embedding_model: str = Field(None, description="The embedding model to use for the assistant.") + + ## TODO: remove + # user_id: str = Field(..., description="The unique identifier of the user.") + + +class CreateThreadRequest(BaseModel): + messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.") + + # memgpt-only 
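+ # (fields under a "memgpt-only" marker extend the stock OpenAI wire format with MemGPT-specific data, e.g. embedding_model above and assistant_name below; clients targeting the plain OpenAI API can omit them)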
+ assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. MemGPT preset)") + + +class ModifyThreadRequest(BaseModel): + metadata: dict = Field(None, description="Metadata associated with the thread.") + + +class ModifyMessageRequest(BaseModel): + metadata: dict = Field(None, description="Metadata associated with the message.") + + +class ModifyRunRequest(BaseModel): + metadata: dict = Field(None, description="Metadata associated with the run.") + + +class CreateMessageRequest(BaseModel): + role: str = Field(..., description="Role of the message sender (either 'user' or 'system')") + content: str = Field(..., description="The message content to be processed by the agent.") + file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the message.") + + +class UserMessageRequest(BaseModel): + user_id: str = Field(..., description="The unique identifier of the user.") + agent_id: str = Field(..., description="The unique identifier of the agent.") + message: str = Field(..., description="The message content to be processed by the agent.") + stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.") + role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')") + + +class UserMessageResponse(BaseModel): + messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.") + + +class GetAgentMessagesRequest(BaseModel): + user_id: str = Field(..., description="The unique identifier of the user.") + agent_id: str = Field(..., description="The unique identifier of the agent.") + start: int = Field(..., description="Message index to start on (reverse chronological).") + count: int = Field(..., description="How many messages to retrieve.") + + +class ListMessagesResponse(BaseModel): + messages: List[OpenAIMessage] = Field(..., description="List of message objects.") + + +class CreateAssistantFileRequest(BaseModel): + file_id: str = Field(..., description="The unique identifier of the file.") + + +class CreateRunRequest(BaseModel): + assistant_id: str = Field(..., description="The unique identifier of the assistant.") + model: Optional[str] = Field(None, description="The model used by the run.") + instructions: str = Field(..., description="The instructions for the run.") + additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.") + tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).") + metadata: Optional[dict] = Field(None, description="Metadata associated with the run.") + + +class CreateThreadRunRequest(BaseModel): + assistant_id: str = Field(..., description="The unique identifier of the assistant.") + thread: OpenAIThread = Field(..., description="The thread to run.") + model: str = Field(..., description="The model used by the run.") + instructions: str = Field(..., description="The instructions for the run.") + tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).") + metadata: Optional[dict] = Field(None, description="Metadata associated with the run.") + + +class DeleteAssistantResponse(BaseModel): + id: str = Field(..., description="The unique identifier of the agent.") + object: str = 
"assistant.deleted" + deleted: bool = Field(..., description="Whether the agent was deleted.") + + +class DeleteAssistantFileResponse(BaseModel): + id: str = Field(..., description="The unique identifier of the file.") + object: str = "assistant.file.deleted" + deleted: bool = Field(..., description="Whether the file was deleted.") + + +class DeleteThreadResponse(BaseModel): + id: str = Field(..., description="The unique identifier of the agent.") + object: str = "thread.deleted" + deleted: bool = Field(..., description="Whether the agent was deleted.") + + +class SubmitToolOutputsToRunRequest(BaseModel): + tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.") diff --git a/memgpt/server/rest_api/routers/openai/assistants/threads.py b/memgpt/server/rest_api/routers/openai/assistants/threads.py new file mode 100644 index 00000000..86c6f621 --- /dev/null +++ b/memgpt/server/rest_api/routers/openai/assistants/threads.py @@ -0,0 +1,336 @@ +import uuid +from typing import TYPE_CHECKING, List + +from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query + +from memgpt.constants import DEFAULT_PRESET +from memgpt.schemas.agent import CreateAgent +from memgpt.schemas.enums import MessageRole +from memgpt.schemas.message import Message +from memgpt.schemas.openai.openai import ( + MessageFile, + OpenAIMessage, + OpenAIRun, + OpenAIRunStep, + OpenAIThread, + Text, +) +from memgpt.server.rest_api.routers.openai.assistants.schemas import ( + CreateMessageRequest, + CreateRunRequest, + CreateThreadRequest, + CreateThreadRunRequest, + DeleteThreadResponse, + ListMessagesResponse, + ModifyMessageRequest, + ModifyRunRequest, + ModifyThreadRequest, + OpenAIThread, + SubmitToolOutputsToRunRequest, +) +from memgpt.server.rest_api.utils import get_memgpt_server +from memgpt.server.server import SyncServer + +if TYPE_CHECKING: + from memgpt.utils import get_utc_time + + +# TODO: implement mechanism for creating/authenticating users associated with a bearer token +router = APIRouter(prefix="/v1/threads", tags=["threads"]) + + +@router.post("/", response_model=OpenAIThread) +def create_thread( + request: CreateThreadRequest = Body(...), + server: SyncServer = Depends(get_memgpt_server), +): + # TODO: use requests.description and requests.metadata fields + # TODO: handle requests.file_ids and requests.tools + # TODO: eventually allow request to override embedding/llm model + actor = server.get_current_user() + + print("Create thread/agent", request) + # create a memgpt agent + agent_state = server.create_agent( + request=CreateAgent(), + user_id=actor.id, + ) + # TODO: insert messages into recall memory + return OpenAIThread( + id=str(agent_state.id), + created_at=int(agent_state.created_at.timestamp()), + metadata={}, # TODO add metadata? + ) + + +@router.get("/{thread_id}", response_model=OpenAIThread) +def retrieve_thread( + thread_id: str = Path(..., description="The unique identifier of the thread."), + server: SyncServer = Depends(get_memgpt_server), +): + actor = server.get_current_user() + agent = server.get_agent(user_id=actor.id, agent_id=thread_id) + assert agent is not None + return OpenAIThread( + id=str(agent.id), + created_at=int(agent.created_at.timestamp()), + metadata={}, # TODO add metadata? 
+ ) + + +@router.post("/{thread_id}", response_model=OpenAIThread) +def modify_thread( + thread_id: str = Path(..., description="The unique identifier of the thread."), + request: ModifyThreadRequest = Body(...), +): + # TODO: add agent metadata so this can be modified + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.delete("/{thread_id}", response_model=DeleteThreadResponse) +def delete_thread( + thread_id: str = Path(..., description="The unique identifier of the thread."), +): + # TODO: delete agent + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.post("/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage) +def create_message( + thread_id: str = Path(..., description="The unique identifier of the thread."), + request: CreateMessageRequest = Body(...), + server: SyncServer = Depends(get_memgpt_server), +): + actor = server.get_current_user() + agent_id = thread_id + # create message object + message = Message( + user_id=actor.id, + agent_id=agent_id, + role=MessageRole(request.role), + text=request.content, + model=None, + tool_calls=None, + tool_call_id=None, + name=None, + ) + agent = server._get_or_load_agent(agent_id=agent_id) + # add message to agent + agent._append_to_messages([message]) + + openai_message = OpenAIMessage( + id=str(message.id), + created_at=int(message.created_at.timestamp()), + content=[Text(text=(message.text if message.text else ""))], + role=message.role, + thread_id=str(message.agent_id), + assistant_id=DEFAULT_PRESET, # TODO: update this + # TODO(sarah) fill in? + run_id=None, + file_ids=None, + metadata=None, + # file_ids=message.file_ids, + # metadata=message.metadata, + ) + return openai_message + + +@router.get("/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse) +def list_messages( + thread_id: str = Path(..., description="The unique identifier of the thread."), + limit: int = Query(1000, description="How many messages to retrieve."), + order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."), + server: SyncServer = Depends(get_memgpt_server), +): + actor = server.get_current_user() + after_uuid = after if after else None + before_uuid = before if before else None + agent_id = thread_id + reverse = True if (order == "desc") else False + json_messages = server.get_agent_recall_cursor( + user_id=actor.id, + agent_id=agent_id, + limit=limit, + after=after_uuid, + before=before_uuid, + order_by="created_at", + reverse=reverse, + return_message_object=True, + ) + assert isinstance(json_messages, List) + assert all([isinstance(message, Message) for message in json_messages]) + assert isinstance(json_messages[0], Message) + print(json_messages[0].text) + # convert to openai style messages + openai_messages = [] + for message in json_messages: + assert isinstance(message, Message) + openai_messages.append( + OpenAIMessage( + id=str(message.id), + created_at=int(message.created_at.timestamp()), + content=[Text(text=(message.text if message.text else ""))], + role=str(message.role), + thread_id=str(message.agent_id), + assistant_id=DEFAULT_PRESET, # TODO: update this + # TODO(sarah) fill in? 
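+ # (runs, message files, and per-message metadata are not tracked server-side yet, hence the None placeholders that follow)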
+ run_id=None, + file_ids=None, + metadata=None, + # file_ids=message.file_ids, + # metadata=message.metadata, + ) + ) + print("MESSAGES", openai_messages) + # TODO: cast back to message objects + return ListMessagesResponse(messages=openai_messages) + + +@router.get("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage) +def retrieve_message( + thread_id: str = Path(..., description="The unique identifier of the thread."), + message_id: str = Path(..., description="The unique identifier of the message."), + server: SyncServer = Depends(get_memgpt_server), +): + agent_id = thread_id + message = server.get_agent_message(agent_id=agent_id, message_id=message_id) + assert message is not None + return OpenAIMessage( + id=message_id, + created_at=int(message.created_at.timestamp()), + content=[Text(text=(message.text if message.text else ""))], + role=message.role, + thread_id=str(message.agent_id), + assistant_id=DEFAULT_PRESET, # TODO: update this + # TODO(sarah) fill in? + run_id=None, + file_ids=None, + metadata=None, + # file_ids=message.file_ids, + # metadata=message.metadata, + ) + + +@router.get("/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile) +def retrieve_message_file( + thread_id: str = Path(..., description="The unique identifier of the thread."), + message_id: str = Path(..., description="The unique identifier of the message."), + file_id: str = Path(..., description="The unique identifier of the file."), +): + # TODO: implement? + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.post("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage) +def modify_message( + thread_id: str = Path(..., description="The unique identifier of the thread."), + message_id: str = Path(..., description="The unique identifier of the message."), + request: ModifyMessageRequest = Body(...), +): + # TODO: add metadata field to message so this can be modified + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.post("/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun) +def create_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + request: CreateRunRequest = Body(...), + server: SyncServer = Depends(get_memgpt_server), +): + server.get_current_user() + + # TODO: add request.instructions as a message? 
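+ # (note: this run executes synchronously inside the request: agent.step consumes the messages already appended to the thread, which is why the returned OpenAIRun can only report status="completed")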
+ agent_id = thread_id + # TODO: override preset of agent with request.assistant_id + agent = server._get_or_load_agent(agent_id=agent_id) + agent.step(user_message=None) # already has messages added + run_id = str(uuid.uuid4()) + create_time = int(get_utc_time().timestamp()) + return OpenAIRun( + id=run_id, + created_at=create_time, + thread_id=str(agent_id), + assistant_id=DEFAULT_PRESET, # TODO: update this + status="completed", # TODO: eventually allow offline execution + expires_at=create_time, + model=agent.agent_state.llm_config.model, + instructions=request.instructions, + ) + + +@router.post("/runs", tags=["runs"], response_model=OpenAIRun) +def create_thread_and_run( + request: CreateThreadRunRequest = Body(...), +): + # TODO: add a bunch of messages and execute + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.get("/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun]) +def list_runs( + thread_id: str = Path(..., description="The unique identifier of the thread."), + limit: int = Query(1000, description="How many runs to retrieve."), + order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."), +): + # TODO: store run information in a DB so it can be returned here + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.get("/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep]) +def list_run_steps( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), + limit: int = Query(1000, description="How many run steps to retrieve."), + order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."), + after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), + before: str = Query(None, description="A cursor for use in pagination. 
`before` is an object ID that defines your place in the list."), +): + # TODO: store run information in a DB so it can be returned here + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.get("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun) +def retrieve_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), +): + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.get("/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep) +def retrieve_run_step( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), + step_id: str = Path(..., description="The unique identifier of the run step."), +): + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.post("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun) +def modify_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), + request: ModifyRunRequest = Body(...), +): + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun) +def submit_tool_outputs_to_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), + request: SubmitToolOutputsToRunRequest = Body(...), +): + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") + + +@router.post("/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun) +def cancel_run( + thread_id: str = Path(..., description="The unique identifier of the thread."), + run_id: str = Path(..., description="The unique identifier of the run."), +): + raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)") diff --git a/memgpt/server/rest_api/models/__init__.py b/memgpt/server/rest_api/routers/openai/chat_completions/__init__.py similarity index 100% rename from memgpt/server/rest_api/models/__init__.py rename to memgpt/server/rest_api/routers/openai/chat_completions/__init__.py diff --git a/memgpt/server/rest_api/routers/openai/chat_completions/chat_completions.py b/memgpt/server/rest_api/routers/openai/chat_completions/chat_completions.py new file mode 100644 index 00000000..5d5e0024 --- /dev/null +++ b/memgpt/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -0,0 +1,131 @@ +import json +from typing import TYPE_CHECKING + +from fastapi import APIRouter, Body, Depends, HTTPException + +from memgpt.schemas.enums import MessageRole +from memgpt.schemas.memgpt_message import FunctionCall, MemGPTMessage +from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest +from memgpt.schemas.openai.chat_completion_response import ( + ChatCompletionResponse, + Choice, + Message, + UsageStatistics, +) + +# TODO this belongs in a controller! 
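The proxy handler added in the '+' lines that follow (the TODO above refers to the send_message_to_agent import just below) builds the assistant's visible reply by collecting the "message" argument of every send_message function call the agent emits. A standalone sketch of just that extraction step, using hypothetical payloads rather than real agent output:

import json

function_calls = [
    {"name": "send_message", "arguments": json.dumps({"message": "Hi there!"})},
    {"name": "archival_memory_insert", "arguments": json.dumps({"content": "note"})},
]
visible_reply = "".join(
    json.loads(call["arguments"])["message"]
    for call in function_calls
    if call["name"] == "send_message"  # anything else is internal tool use and is skipped
)
assert visible_reply == "Hi there!"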
+from memgpt.server.rest_api.routers.v1.agents import send_message_to_agent
+from memgpt.server.rest_api.utils import get_memgpt_server
+from memgpt.utils import get_utc_time
+
+if TYPE_CHECKING:
+    from memgpt.server.server import SyncServer
+
+router = APIRouter(prefix="/v1/chat/completions", tags=["chat_completions"])
+
+
+@router.post("/", response_model=ChatCompletionResponse)
+async def create_chat_completion(
+    completion_request: ChatCompletionRequest = Body(...),
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    """Send a message to a MemGPT agent via a /chat/completions completion_request.
+    The bearer token will be used to identify the user.
+    The 'user' field in the completion_request should be set to the agent ID.
+    """
+    actor = server.get_current_user()
+    agent_id = completion_request.user
+    if agent_id is None:
+        raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
+
+    messages = completion_request.messages
+    if messages is None:
+        raise HTTPException(status_code=400, detail="'messages' field must not be empty")
+    if len(messages) > 1:
+        raise HTTPException(status_code=400, detail="'messages' field must be a list of length 1")
+    if messages[0].role != "user":
+        raise HTTPException(status_code=400, detail="'messages[0].role' must be a 'user'")
+
+    input_message = completion_request.messages[0]
+    if completion_request.stream:
+        print("Starting streaming OpenAI proxy response")
+
+        # TODO(charles) support multimodal parts
+        assert isinstance(input_message.content, str)
+
+        return await send_message_to_agent(
+            server=server,
+            agent_id=agent_id,
+            user_id=actor.id,
+            role=MessageRole(input_message.role),
+            message=input_message.content,
+            # Turn streaming ON
+            stream_steps=True,
+            stream_tokens=True,
+            # Turn on ChatCompletion mode (eg remaps send_message to content)
+            chat_completion_mode=True,
+            return_message_object=False,
+        )
+
+    else:
+        print("Starting non-streaming OpenAI proxy response")
+
+        # TODO(charles) support multimodal parts
+        assert isinstance(input_message.content, str)
+
+        response_messages = await send_message_to_agent(
+            server=server,
+            agent_id=agent_id,
+            user_id=actor.id,
+            role=MessageRole(input_message.role),
+            message=input_message.content,
+            # Turn streaming OFF
+            stream_steps=False,
+            stream_tokens=False,
+            return_message_object=False,
+        )
+        # print(response_messages)
+
+        # Concatenate all send_message outputs together
+        id = ""
+        visible_message_str = ""
+        created_at = None
+        for memgpt_msg in response_messages.messages:
+            assert isinstance(memgpt_msg, MemGPTMessage)
+            if isinstance(memgpt_msg, FunctionCall):
+                if memgpt_msg.name and memgpt_msg.name == "send_message":
+                    try:
+                        memgpt_function_call_args = json.loads(memgpt_msg.arguments)
+                        visible_message_str += memgpt_function_call_args["message"]
+                        id = memgpt_msg.id
+                        created_at = memgpt_msg.date
+                    except (json.JSONDecodeError, KeyError):
+                        print(f"Failed to parse MemGPT message: {str(memgpt_msg)}")
+                else:
+                    print(f"Skipping function_call: {str(memgpt_msg)}")
+            else:
+                print(f"Skipping message: {str(memgpt_msg)}")
+
+        response = ChatCompletionResponse(
+            id=id,
+            created=created_at if created_at else get_utc_time(),
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=Message(
+                        role="assistant",
+                        content=visible_message_str,
+                    ),
+                )
+            ],
+            # TODO add real usage
+            usage=UsageStatistics(
+                completion_tokens=0,
+                prompt_tokens=0,
+                total_tokens=0,
+            ),
+        )
+        return response
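[Editor's note: a minimal client-side sketch of how this proxy route could be exercised; it is not part of the patch. The host, port, and agent ID are assumptions, and the route shape follows from the "/v1/chat/completions" router prefix mounted under OPENAI_API_PREFIX = "/openai" in server.py later in this patch.]

    # Hypothetical usage sketch -- not part of the patch.
    import requests

    resp = requests.post(
        "http://localhost:8283/openai/v1/chat/completions/",  # host/port are assumptions
        json={
            "model": "gpt-4",  # placeholder; the proxy routes by agent, not by model
            "user": "agent-1234",  # MemGPT reuses the 'user' field to carry the agent ID
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": False,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])
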
diff --git a/memgpt/server/rest_api/routers/v1/__init__.py b/memgpt/server/rest_api/routers/v1/__init__.py
new file mode 100644
index 00000000..17c942aa
--- /dev/null
+++ b/memgpt/server/rest_api/routers/v1/__init__.py
@@ -0,0 +1,15 @@
+from memgpt.server.rest_api.routers.v1.agents import router as agents_router
+from memgpt.server.rest_api.routers.v1.blocks import router as blocks_router
+from memgpt.server.rest_api.routers.v1.jobs import router as jobs_router
+from memgpt.server.rest_api.routers.v1.llms import router as llm_router
+from memgpt.server.rest_api.routers.v1.sources import router as sources_router
+from memgpt.server.rest_api.routers.v1.tools import router as tools_router
+
+ROUTERS = [
+    tools_router,
+    sources_router,
+    agents_router,
+    llm_router,
+    blocks_router,
+    jobs_router,
+]
diff --git a/memgpt/server/rest_api/routers/v1/agents.py b/memgpt/server/rest_api/routers/v1/agents.py
new file mode 100644
index 00000000..38d52bf3
--- /dev/null
+++ b/memgpt/server/rest_api/routers/v1/agents.py
@@ -0,0 +1,529 @@
+import asyncio
+from datetime import datetime
+from typing import Dict, List, Optional, Union
+
+from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
+from memgpt.schemas.enums import MessageRole, MessageStreamStatus
+from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
+from memgpt.schemas.memgpt_request import MemGPTRequest
+from memgpt.schemas.memgpt_response import MemGPTResponse
+from memgpt.schemas.memory import (
+    ArchivalMemorySummary,
+    CreateArchivalMemory,
+    Memory,
+    RecallMemorySummary,
+)
+from memgpt.schemas.message import Message, UpdateMessage
+from memgpt.schemas.passage import Passage
+from memgpt.schemas.source import Source
+from memgpt.server.rest_api.interface import StreamingServerInterface
+from memgpt.server.rest_api.utils import get_memgpt_server, sse_async_generator
+from memgpt.server.server import SyncServer
+from memgpt.utils import deduplicate
+
+# These can be forward refs, but because FastAPI needs them at runtime they must be imported normally
+
+
+router = APIRouter(prefix="/agents", tags=["agents"])
+
+
+@router.get("/", response_model=List[AgentState])
+def list_agents(
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    """
+    List all agents associated with a given user.
+    This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
+    """
+    actor = server.get_current_user()
+
+    return server.list_agents(user_id=actor.id)
+
+
+@router.post("/", response_model=AgentState)
+def create_agent(
+    agent: CreateAgent = Body(...),
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    """
+    Create a new agent with the specified configuration.
+ """ + actor = server.get_current_user() + agent.user_id = actor.id + + return server.create_agent(agent, user_id=actor.id) + + +@router.patch("/{agent_id}", response_model=AgentState) +def update_agent( + agent_id: str, + update_agent: UpdateAgentState = Body(...), + server: "SyncServer" = Depends(get_memgpt_server), +): + """Update an exsiting agent""" + actor = server.get_current_user() + + update_agent.id = agent_id + return server.update_agent(update_agent, user_id=actor.id) + + +@router.get("/{agent_id}", response_model=AgentState) +def get_agent_state( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Get the state of the agent. + """ + actor = server.get_current_user() + + if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id): + # agent does not exist + raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.") + + return server.get_agent_state(user_id=actor.id, agent_id=agent_id) + + +@router.delete("/{agent_id}") +def delete_agent( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Delete an agent. + """ + actor = server.get_current_user() + + return server.delete_agent(user_id=actor.id, agent_id=agent_id) + + +@router.get("/{agent_id}/sources", response_model=List[Source]) +def get_agent_sources( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Get the sources associated with an agent. + """ + server.get_current_user() + + return server.list_attached_sources(agent_id) + + +@router.get("/{agent_id}/memory/messages", response_model=List[Message]) +def get_agent_in_context_messages( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Retrieve the messages in the context of a specific agent. + """ + + return server.get_in_context_messages(agent_id=agent_id) + + +@router.get("/{agent_id}/memory", response_model=Memory) +def get_agent_memory( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Retrieve the memory state of a specific agent. + This endpoint fetches the current memory state of the agent identified by the user ID and agent ID. + """ + + return server.get_agent_memory(agent_id=agent_id) + + +@router.patch("/{agent_id}/memory", response_model=Memory) +def update_agent_memory( + agent_id: str, + request: Dict = Body(...), + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Update the core memory of a specific agent. + This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID. + """ + actor = server.get_current_user() + + memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request) + return memory + + +@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary) +def get_agent_recall_memory_summary( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Retrieve the summary of the recall memory of a specific agent. + """ + + return server.get_recall_memory_summary(agent_id=agent_id) + + +@router.get("/{agent_id}/memory/archival", response_model=ArchivalMemorySummary) +def get_agent_archival_memory_summary( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Retrieve the summary of the archival memory of a specific agent. 
+ """ + + return server.get_archival_memory_summary(agent_id=agent_id) + + +@router.get("/{agent_id}/archival", response_model=List[Passage]) +def get_agent_archival_memory( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), + after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."), + before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."), + limit: Optional[int] = Query(None, description="How many results to include in the response."), +): + """ + Retrieve the memories in an agent's archival memory store (paginated query). + """ + actor = server.get_current_user() + + # TODO need to add support for non-postgres here + # chroma will throw: + # raise ValueError("Cannot run get_all_cursor with chroma") + + return server.get_agent_archival_cursor( + user_id=actor.id, + agent_id=agent_id, + after=after, + before=before, + limit=limit, + ) + + +@router.post("/{agent_id}/archival", response_model=List[Passage]) +def insert_agent_archival_memory( + agent_id: str, + request: CreateArchivalMemory = Body(...), + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Insert a memory into an agent's archival memory store. + """ + actor = server.get_current_user() + + return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text) + + +# TODO(ethan): query or path parameter for memory_id? +# @router.delete("/{agent_id}/archival") +@router.delete("/{agent_id}/archival/{memory_id}") +def delete_agent_archival_memory( + agent_id: str, + memory_id: str, + # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."), + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Delete a memory from an agent's archival memory store. + """ + actor = server.get_current_user() + + server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id) + return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) + + +@router.get("/{agent_id}/messages", response_model=List[Message]) +def get_agent_messages( + agent_id: str, + server: "SyncServer" = Depends(get_memgpt_server), + before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), + limit: int = Query(10, description="Maximum number of messages to retrieve."), + msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."), +): + """ + Retrieve message history for an agent. + """ + actor = server.get_current_user() + + return server.get_agent_recall_cursor( + user_id=actor.id, + agent_id=agent_id, + before=before, + limit=limit, + reverse=True, + return_message_object=msg_object, + ) + + +@router.patch("/{agent_id}/messages/{message_id}", response_model=Message) +def update_message( + agent_id: str, + message_id: str, + request: UpdateMessage = Body(...), + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Update the details of a message associated with an agent. 
+ """ + assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}" + return server.update_agent_message(agent_id=agent_id, request=request) + + +@router.post("/{agent_id}/messages", response_model=MemGPTResponse) +async def send_message( + agent_id: str, + server: SyncServer = Depends(get_memgpt_server), + request: MemGPTRequest = Body(...), +): + """ + Process a user message and return the agent's response. + This endpoint accepts a message from a user and processes it through the agent. + It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True. + """ + actor = server.get_current_user() + + # TODO(charles): support sending multiple messages + assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}" + message = request.messages[0] + + return await send_message_to_agent( + server=server, + agent_id=agent_id, + user_id=actor.id, + role=message.role, + message=message.text, + stream_steps=request.stream_steps, + stream_tokens=request.stream_tokens, + return_message_object=request.return_message_object, + ) + + +# TODO: move this into server.py? +async def send_message_to_agent( + server: SyncServer, + agent_id: str, + user_id: str, + role: MessageRole, + message: str, + stream_steps: bool, + stream_tokens: bool, + return_message_object: bool, # Should be True for Python Client, False for REST API + chat_completion_mode: Optional[bool] = False, + timestamp: Optional[datetime] = None, + # related to whether or not we return `MemGPTMessage`s or `Message`s +) -> Union[StreamingResponse, MemGPTResponse]: + """Split off into a separate function so that it can be imported in the /chat/completion proxy.""" + # TODO: @charles is this the correct way to handle? + include_final_message = True + + # determine role + if role == MessageRole.user: + message_func = server.user_message + elif role == MessageRole.system: + message_func = server.system_message + else: + raise HTTPException(status_code=500, detail=f"Bad role {role}") + + if not stream_steps and stream_tokens: + raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'") + + # For streaming response + try: + + # TODO: move this logic into server.py + + # Get the generator object off of the agent's streaming interface + # This will be attached to the POST SSE request used under-the-hood + memgpt_agent = server._get_or_load_agent(agent_id=agent_id) + streaming_interface = memgpt_agent.interface + if not isinstance(streaming_interface, StreamingServerInterface): + raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") + + # Enable token-streaming within the request if desired + streaming_interface.streaming_mode = stream_tokens + # "chatcompletion mode" does some remapping and ignores inner thoughts + streaming_interface.streaming_chat_completion_mode = chat_completion_mode + + # streaming_interface.allow_assistant_message = stream + # streaming_interface.function_call_legacy_mode = stream + + # Offload the synchronous message_func to a separate thread + streaming_interface.stream_start() + task = asyncio.create_task( + asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp) + ) + + if stream_steps: + if return_message_object: + # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format + raise NotImplementedError + + # return a stream + return StreamingResponse( + 
+                sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
+                media_type="text/event-stream",
+            )
+
+        else:
+            # buffer the stream, then return the list
+            generated_stream = []
+            async for message in streaming_interface.get_generator():
+                assert (
+                    isinstance(message, MemGPTMessage)
+                    or isinstance(message, LegacyMemGPTMessage)
+                    or isinstance(message, MessageStreamStatus)
+                ), type(message)
+                generated_stream.append(message)
+                if message == MessageStreamStatus.done:
+                    break
+
+            # Get rid of the stream status messages
+            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
+            usage = await task
+
+            # By default the stream will be messages of type MemGPTMessage or LegacyMemGPTMessage
+            # If we want to convert these to Message, we can use the attached IDs
+            # NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
+            # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
+            if return_message_object:
+                message_ids = [m.id for m in filtered_stream]
+                message_ids = deduplicate(message_ids)
+                message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
+                return MemGPTResponse(messages=message_objs, usage=usage)
+            else:
+                return MemGPTResponse(messages=filtered_stream, usage=usage)
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        print(e)
+        import traceback
+
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"{e}")
+
+
+##### MISSING #######
+
+# @router.post("/{agent_id}/command")
+# def run_command(
+#     agent_id: "UUID",
+#     command: "AgentCommandRequest",
+#
+#     server: "SyncServer" = Depends(get_memgpt_server),
+# ):
+#     """
+#     Execute a command on a specified agent.
+
+#     This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately.
+
+#     Raises an HTTPException for any processing errors.
+#     """
+#     actor = server.get_current_user()
+#
+#     response = server.run_command(user_id=actor.id,
+#                                   agent_id=agent_id,
+#                                   command=command.command)
+
+#     return AgentCommandResponse(response=response)
+
+# @router.get("/{agent_id}/config")
+# def get_agent_config(
+#     agent_id: "UUID",
+#
+#     server: "SyncServer" = Depends(get_memgpt_server),
+# ):
+#     """
+#     Retrieve the configuration for a specific agent.
+
+#     This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
+# """ +# actor = server.get_current_user() +# +# if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id): +## agent does not exist +# raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.") + +# agent_state = server.get_agent_config(user_id=actor.id, agent_id=agent_id) +## get sources +# attached_sources = server.list_attached_sources(agent_id=agent_id) + +## configs +# llm_config = LLMConfig(**vars(agent_state.llm_config)) +# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config)) + +# return GetAgentResponse( +# agent_state=AgentState( +# id=agent_state.id, +# name=agent_state.name, +# user_id=agent_state.user_id, +# llm_config=llm_config, +# embedding_config=embedding_config, +# state=agent_state.state, +# created_at=int(agent_state.created_at.timestamp()), +# tools=agent_state.tools, +# system=agent_state.system, +# metadata=agent_state._metadata, +# ), +# last_run_at=None, # TODO +# sources=attached_sources, +# ) + +# @router.patch("/{agent_id}/rename", response_model=GetAgentResponse) +# def update_agent_name( +# agent_id: "UUID", +# agent_rename: AgentRenameRequest, +# +# server: "SyncServer" = Depends(get_memgpt_server), +# ): +# """ +# Updates the name of a specific agent. + +# This changes the name of the agent in the database but does NOT edit the agent's persona. +# """ +# valid_name = agent_rename.agent_name +# actor = server.get_current_user() +# +# agent_state = server.rename_agent(user_id=actor.id, agent_id=agent_id, new_agent_name=valid_name) +## get sources +# attached_sources = server.list_attached_sources(agent_id=agent_id) +# llm_config = LLMConfig(**vars(agent_state.llm_config)) +# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config)) + +# return GetAgentResponse( +# agent_state=AgentState( +# id=agent_state.id, +# name=agent_state.name, +# user_id=agent_state.user_id, +# llm_config=llm_config, +# embedding_config=embedding_config, +# state=agent_state.state, +# created_at=int(agent_state.created_at.timestamp()), +# tools=agent_state.tools, +# system=agent_state.system, +# ), +# last_run_at=None, # TODO +# sources=attached_sources, +# ) + + +# @router.get("/{agent_id}/archival/all", response_model=GetAgentArchivalMemoryResponse) +# def get_agent_archival_memory_all( +# agent_id: "UUID", +# +# server: "SyncServer" = Depends(get_memgpt_server), +# ): +# """ +# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once). 
+# """ +# actor = server.get_current_user() +# +# archival_memories = server.get_all_archival_memories(user_id=actor.id, agent_id=agent_id) +# print("archival_memories:", archival_memories) +# archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories] +# return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects) diff --git a/memgpt/server/rest_api/routers/v1/blocks.py b/memgpt/server/rest_api/routers/v1/blocks.py new file mode 100644 index 00000000..23091d70 --- /dev/null +++ b/memgpt/server/rest_api/routers/v1/blocks.py @@ -0,0 +1,73 @@ +from typing import TYPE_CHECKING, List, Optional + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from memgpt.schemas.block import Block, CreateBlock, UpdateBlock +from memgpt.server.rest_api.utils import get_memgpt_server +from memgpt.server.server import SyncServer + +if TYPE_CHECKING: + pass + +router = APIRouter(prefix="/blocks", tags=["blocks"]) + + +@router.get("/", response_model=List[Block]) +def list_blocks( + # query parameters + label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"), + templates_only: bool = Query(True, description="Whether to include only templates"), + name: Optional[str] = Query(None, description="Name of the block"), + server: SyncServer = Depends(get_memgpt_server), +): + actor = server.get_current_user() + + blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name) + if blocks is None: + return [] + return blocks + + +@router.post("/", response_model=Block) +def create_block( + create_block: CreateBlock = Body(...), + server: SyncServer = Depends(get_memgpt_server), +): + actor = server.get_current_user() + + create_block.user_id = actor.id + return server.create_block(user_id=actor.id, request=create_block) + + +@router.patch("/{block_id}", response_model=Block) +def update_block( + block_id: str, + updated_block: UpdateBlock = Body(...), + server: SyncServer = Depends(get_memgpt_server), +): + # actor = server.get_current_user() + + updated_block.id = block_id + return server.update_block(request=updated_block) + + +# TODO: delete should not return anything +@router.delete("/{block_id}", response_model=Block) +def delete_block( + block_id: str, + server: SyncServer = Depends(get_memgpt_server), +): + + return server.delete_block(block_id=block_id) + + +@router.get("/{block_id}", response_model=Block) +def get_block( + block_id: str, + server: SyncServer = Depends(get_memgpt_server), +): + + block = server.get_block(block_id=block_id) + if block is None: + raise HTTPException(status_code=404, detail="Block not found") + return block diff --git a/memgpt/server/rest_api/routers/v1/jobs.py b/memgpt/server/rest_api/routers/v1/jobs.py new file mode 100644 index 00000000..633d493a --- /dev/null +++ b/memgpt/server/rest_api/routers/v1/jobs.py @@ -0,0 +1,46 @@ +from typing import List + +from fastapi import APIRouter, Depends + +from memgpt.schemas.job import Job +from memgpt.server.rest_api.utils import get_memgpt_server +from memgpt.server.server import SyncServer + +router = APIRouter(prefix="/jobs", tags=["jobs"]) + + +@router.get("/", response_model=List[Job]) +def list_jobs( + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + List all jobs. 
+ """ + actor = server.get_current_user() + + # TODO: add filtering by status + return server.list_jobs(user_id=actor.id) + + +@router.get("/active", response_model=List[Job]) +def list_active_jobs( + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + List all active jobs. + """ + actor = server.get_current_user() + + return server.list_active_jobs(user_id=actor.id) + + +@router.get("/{job_id}", response_model=Job) +def get_job( + job_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Get the status of a job. + """ + + return server.get_job(job_id=job_id) diff --git a/memgpt/server/rest_api/routers/v1/llms.py b/memgpt/server/rest_api/routers/v1/llms.py new file mode 100644 index 00000000..5b5e8d2a --- /dev/null +++ b/memgpt/server/rest_api/routers/v1/llms.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING, List + +from fastapi import APIRouter, Depends + +from memgpt.schemas.embedding_config import EmbeddingConfig +from memgpt.schemas.llm_config import LLMConfig +from memgpt.server.rest_api.utils import get_memgpt_server + +if TYPE_CHECKING: + from memgpt.server.server import SyncServer + +router = APIRouter(prefix="/models", tags=["models", "llms"]) + + +@router.get("/", response_model=List[LLMConfig]) +def list_llm_backends( + server: "SyncServer" = Depends(get_memgpt_server), +): + + return server.list_models() + + +@router.get("/embedding", response_model=List[EmbeddingConfig]) +def list_embedding_backends( + server: "SyncServer" = Depends(get_memgpt_server), +): + + return server.list_embedding_models() diff --git a/memgpt/server/rest_api/routers/v1/sources.py b/memgpt/server/rest_api/routers/v1/sources.py new file mode 100644 index 00000000..eeb1fc6c --- /dev/null +++ b/memgpt/server/rest_api/routers/v1/sources.py @@ -0,0 +1,199 @@ +import os +import tempfile +from typing import List + +from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile + +from memgpt.schemas.document import Document +from memgpt.schemas.job import Job +from memgpt.schemas.passage import Passage +from memgpt.schemas.source import Source, SourceCreate, SourceUpdate +from memgpt.server.rest_api.utils import get_memgpt_server +from memgpt.server.server import SyncServer + +# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally + + +router = APIRouter(prefix="/sources", tags=["sources"]) + + +@router.get("/{source_id}", response_model=Source) +def get_source( + source_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Get all sources + """ + actor = server.get_current_user() + + return server.get_source(source_id=source_id, user_id=actor.id) + + +@router.get("/name/{source_name}", response_model=str) +def get_source_id_by_name( + source_name: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Get a source by name + """ + actor = server.get_current_user() + + source = server.get_source_id(source_name=source_name, user_id=actor.id) + return source + + +@router.get("/", response_model=List[Source]) +def list_sources( + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + List all data sources created by a user. + """ + actor = server.get_current_user() + + return server.list_all_sources(user_id=actor.id) + + +@router.post("/", response_model=Source) +def create_source( + source: SourceCreate, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Create a new data source. 
+ """ + actor = server.get_current_user() + + return server.create_source(request=source, user_id=actor.id) + + +@router.patch("/{source_id}", response_model=Source) +def update_source( + source_id: str, + source: SourceUpdate, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Update the name or documentation of an existing data source. + """ + actor = server.get_current_user() + assert source.id == source_id, "Source ID in path must match ID in request body" + + return server.update_source(request=source, user_id=actor.id) + + +@router.delete("/{source_id}") +def delete_source( + source_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Delete a data source. + """ + actor = server.get_current_user() + + server.delete_source(source_id=source_id, user_id=actor.id) + + +@router.post("/{source_id}/attach", response_model=Source) +def attach_source_to_agent( + source_id: str, + agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Attach a data source to an existing agent. + """ + actor = server.get_current_user() + + source = server.ms.get_source(source_id=source_id, user_id=actor.id) + assert source is not None, f"Source with id={source_id} not found." + source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id) + return source + + +@router.post("/{source_id}/detach") +def detach_source_from_agent( + source_id: str, + agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."), + server: "SyncServer" = Depends(get_memgpt_server), +) -> None: + """ + Detach a data source from an existing agent. + """ + actor = server.get_current_user() + + server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id) + + +@router.post("/{source_id}/upload", response_model=Job) +def upload_file_to_source( + file: UploadFile, + source_id: str, + background_tasks: BackgroundTasks, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Upload a file to a data source. + """ + actor = server.get_current_user() + + source = server.ms.get_source(source_id=source_id, user_id=actor.id) + assert source is not None, f"Source with id={source_id} not found." + bytes = file.file.read() + + # create job + job = Job( + user_id=actor.id, + metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id}, + completed_at=None, + ) + job_id = job.id + server.ms.create_job(job) + + # create background task + background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes) + + # return job information + job = server.ms.get_job(job_id=job_id) + assert job is not None, "Job not found" + return job + + +@router.get("/{source_id}/passages", response_model=List[Passage]) +def list_passages( + source_id: str, + server: SyncServer = Depends(get_memgpt_server), +): + """ + List all passages associated with a data source. + """ + actor = server.get_current_user() + passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id) + return passages + + +@router.get("/{source_id}/documents", response_model=List[Document]) +def list_documents( + source_id: str, + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + List all documents associated with a data source. 
+ """ + actor = server.get_current_user() + + documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id) + return documents + + +def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes): + # write the file to a temporary directory (deleted after the context manager exits) + with tempfile.TemporaryDirectory() as tmpdirname: + file_path = os.path.join(str(tmpdirname), str(file.filename)) + with open(file_path, "wb") as buffer: + buffer.write(bytes) + + server.load_file_to_source(source_id, file_path, job_id) diff --git a/memgpt/server/rest_api/routers/v1/tools.py b/memgpt/server/rest_api/routers/v1/tools.py new file mode 100644 index 00000000..7c2d803b --- /dev/null +++ b/memgpt/server/rest_api/routers/v1/tools.py @@ -0,0 +1,103 @@ +from typing import List + +from fastapi import APIRouter, Body, Depends, HTTPException + +from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate +from memgpt.server.rest_api.utils import get_memgpt_server +from memgpt.server.server import SyncServer + +router = APIRouter(prefix="/tools", tags=["tools"]) + + +@router.delete("/{tool_id}") +def delete_tool( + tool_id: str, + server: SyncServer = Depends(get_memgpt_server), +): + """ + Delete a tool by name + """ + # actor = server.get_current_user() + server.delete_tool(tool_id=tool_id) + + +@router.get("/{tool_id}", response_model=Tool) +def get_tool( + tool_id: str, + server: SyncServer = Depends(get_memgpt_server), +): + """ + Get a tool by ID + """ + # actor = server.get_current_user() + + tool = server.get_tool(tool_id=tool_id) + if tool is None: + # return 404 error + raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.") + return tool + + +@router.get("/name/{tool_name}", response_model=str) +def get_tool_id( + tool_name: str, + server: SyncServer = Depends(get_memgpt_server), +): + """ + Get a tool ID by name + """ + actor = server.get_current_user() + + tool_id = server.get_tool_id(tool_name, user_id=actor.id) + if tool_id is None: + # return 404 error + raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.") + return tool_id + + +@router.get("/", response_model=List[Tool]) +def list_all_tools( + server: SyncServer = Depends(get_memgpt_server), +): + """ + Get a list of all tools available to agents created by a user + """ + actor = server.get_current_user() + actor.id + + # TODO: add back when user-specific + return server.list_tools(user_id=actor.id) + # return server.ms.list_tools(user_id=None) + + +@router.post("/", response_model=Tool) +def create_tool( + tool: ToolCreate = Body(...), + update: bool = False, + server: SyncServer = Depends(get_memgpt_server), +): + """ + Create a new tool + """ + actor = server.get_current_user() + + return server.create_tool( + request=tool, + # update=update, + update=True, + user_id=actor.id, + ) + + +@router.patch("/{tool_id}", response_model=Tool) +def update_tool( + tool_id: str, + request: ToolUpdate = Body(...), + server: SyncServer = Depends(get_memgpt_server), +): + """ + Update an existing tool + """ + assert tool_id == request.id, "Tool ID in path must match tool ID in request body" + server.get_current_user() + return server.update_tool(request) diff --git a/memgpt/server/rest_api/routers/v1/users.py b/memgpt/server/rest_api/routers/v1/users.py new file mode 100644 index 00000000..6e197e1a --- /dev/null +++ b/memgpt/server/rest_api/routers/v1/users.py @@ -0,0 +1,109 @@ +from typing import TYPE_CHECKING, List, 
diff --git a/memgpt/server/rest_api/routers/v1/users.py b/memgpt/server/rest_api/routers/v1/users.py
new file mode 100644
index 00000000..6e197e1a
--- /dev/null
+++ b/memgpt/server/rest_api/routers/v1/users.py
@@ -0,0 +1,109 @@
+from typing import TYPE_CHECKING, List, Optional
+
+from fastapi import APIRouter, Body, Depends, HTTPException, Query
+
+from memgpt.schemas.api_key import APIKey, APIKeyCreate
+from memgpt.schemas.user import User, UserCreate
+from memgpt.server.rest_api.utils import get_memgpt_server
+
+# from memgpt.server.schemas.users import (
+#     CreateAPIKeyRequest,
+#     CreateAPIKeyResponse,
+#     CreateUserRequest,
+#     CreateUserResponse,
+#     DeleteAPIKeyResponse,
+#     DeleteUserResponse,
+#     GetAllUsersResponse,
+#     GetAPIKeysResponse,
+# )
+
+if TYPE_CHECKING:
+    from memgpt.schemas.user import User
+    from memgpt.server.server import SyncServer
+
+
+router = APIRouter(prefix="/users", tags=["users", "admin"])
+
+
+@router.get("/", tags=["admin"], response_model=List[User])
+def get_all_users(
+    cursor: Optional[str] = Query(None),
+    limit: Optional[int] = Query(50),
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    """
+    Get a list of all users in the database
+    """
+    try:
+        next_cursor, users = server.ms.get_all_users(cursor=cursor, limit=limit)
+        # processed_users = [{"user_id": user.id} for user in users]
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"{e}")
+    return users
+
+
+@router.post("/", tags=["admin"], response_model=User)
+def create_user(
+    request: UserCreate = Body(...),
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    """
+    Create a new user in the database
+    """
+
+    user = server.create_user(request)
+    return user
+
+
+@router.delete("/", tags=["admin"], response_model=User)
+def delete_user(
+    user_id: str = Query(..., description="The user_id key to be deleted."),
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    # TODO make a soft deletion, instead of a hard deletion
+    try:
+        user = server.ms.get_user(user_id=user_id)
+        if user is None:
+            raise HTTPException(status_code=404, detail="User does not exist")
+        server.ms.delete_user(user_id=user_id)
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"{e}")
+    return user
+
+
+@router.post("/keys", response_model=APIKey)
+def create_new_api_key(
+    create_key: APIKeyCreate = Body(...),
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    """
+    Create a new API key for a user
+    """
+    api_key = server.create_api_key(create_key)
+    return api_key
+
+
+@router.get("/keys", response_model=List[APIKey])
+def get_api_keys(
+    user_id: str = Query(..., description="The unique identifier of the user."),
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    """
+    Get a list of all API keys for a user
+    """
+    if server.ms.get_user(user_id=user_id) is None:
+        raise HTTPException(status_code=404, detail="User does not exist")
+    api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id)
+    return api_keys
+
+
+@router.delete("/keys", response_model=APIKey)
+def delete_api_key(
+    api_key: str = Query(..., description="The API key to be deleted."),
+    server: "SyncServer" = Depends(get_memgpt_server),
+):
+    return server.delete_api_key(api_key)
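[Editor's note: since `users_router` is mounted under `ADMIN_PREFIX = "/v1/admin"` in server.py below, the endpoints above end up at `/v1/admin/users/...`. A hedged sketch of the admin flow follows; it is not part of the patch, and the `APIKeyCreate` payload shape is an assumption. `UserCreate` is called with no arguments elsewhere in this patch, so an empty body is used here.]

    # Hypothetical usage sketch -- not part of the patch.
    import requests

    ADMIN = "http://localhost:8283/v1/admin/users"  # host/port are assumptions

    user = requests.post(f"{ADMIN}/", json={}).json()                          # create a user
    key = requests.post(f"{ADMIN}/keys", json={"user_id": user["id"]}).json()  # mint an API key
    keys = requests.get(f"{ADMIN}/keys", params={"user_id": user["id"]}).json()  # list the user's keys
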
diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py
index 7e224d77..42ae5732 100644
--- a/memgpt/server/rest_api/server.py
+++ b/memgpt/server/rest_api/server.py
@@ -12,27 +12,24 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from starlette.middleware.cors import CORSMiddleware
 
 from memgpt.server.constants import REST_DEFAULT_PORT
-from memgpt.server.rest_api.admin.agents import setup_agents_admin_router
-from memgpt.server.rest_api.admin.tools import setup_tools_index_router
-from memgpt.server.rest_api.admin.users import setup_admin_router
-from memgpt.server.rest_api.agents.index import setup_agents_index_router
-from memgpt.server.rest_api.agents.memory import setup_agents_memory_router
-from memgpt.server.rest_api.agents.message import setup_agents_message_router
-from memgpt.server.rest_api.auth.index import setup_auth_router
-from memgpt.server.rest_api.block.index import setup_block_index_router
-from memgpt.server.rest_api.config.index import setup_config_index_router
+from memgpt.server.rest_api.auth.index import (
+    setup_auth_router,  # TODO: probably remove right?
+)
 from memgpt.server.rest_api.interface import StreamingServerInterface
-from memgpt.server.rest_api.jobs.index import setup_jobs_index_router
-from memgpt.server.rest_api.models.index import setup_models_index_router
-from memgpt.server.rest_api.openai_assistants.assistants import (
-    setup_openai_assistant_router,
+from memgpt.server.rest_api.routers.openai.assistants.assistants import (
+    router as openai_assistants_router,
 )
-from memgpt.server.rest_api.openai_chat_completions.chat_completions import (
-    setup_openai_chat_completions_router,
+from memgpt.server.rest_api.routers.openai.assistants.threads import (
+    router as openai_threads_router,
+)
+from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import (
+    router as openai_chat_completions_router,
+)
+from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
+from memgpt.server.rest_api.routers.v1.users import (
+    router as users_router,  # TODO: decide on admin
 )
-from memgpt.server.rest_api.sources.index import setup_sources_index_router
 from memgpt.server.rest_api.static_files import mount_static_files
-from memgpt.server.rest_api.tools.index import setup_user_tools_index_router
 from memgpt.server.server import SyncServer
 from memgpt.settings import settings
@@ -44,8 +41,6 @@ Start the server with:
 poetry run uvicorn server:app --reload
 """
-# interface: QueuingInterface = QueuingInterface()
-# interface: StreamingServerInterface = StreamingServerInterface()
 interface: StreamingServerInterface = StreamingServerInterface
 server: SyncServer = SyncServer(default_interface_factory=lambda: interface())
@@ -66,10 +61,9 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security
     raise HTTPException(status_code=401, detail="Unauthorized")
-ADMIN_PREFIX = "/admin"
-ADMIN_API_PREFIX = "/api/admin"
-API_PREFIX = "/api"
-OPENAI_API_PREFIX = "/v1"
+ADMIN_PREFIX = "/v1/admin"
+API_PREFIX = "/v1"
+OPENAI_API_PREFIX = "/openai"
 app = FastAPI()
@@ -81,36 +75,51 @@ app.add_middleware(
     allow_headers=["*"],
 )
+# v1_routes are the MemGPT API routes
+for route in v1_routes:
+    app.include_router(route, prefix=API_PREFIX)
+    # this gives undocumented routes for "latest" and bare api calls.
+    # we should always tie this to the newest version of the api.
+ app.include_router(route, prefix="", include_in_schema=False) + app.include_router(route, prefix="/latest", include_in_schema=False) + +# admin/users +app.include_router(users_router, prefix=ADMIN_PREFIX) + +# openai +app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX) +app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX) +app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX) + # /api/auth endpoints app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX) -# /admin/users endpoints -app.include_router(setup_admin_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)]) -app.include_router(setup_tools_index_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)]) +# # Serve static files +# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files") +# app.mount("/assets", StaticFiles(directory=os.path.join(static_files_path, "assets")), name="static") -# /api/admin/agents endpoints -app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)]) -# /api/agents endpoints -app.include_router(setup_agents_index_router(server, interface, password), prefix=API_PREFIX) -app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX) -app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX) -app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX) -app.include_router(setup_jobs_index_router(server, interface, password), prefix=API_PREFIX) -app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX) -app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX) -app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX) +# # Serve favicon +# @app.get("/favicon.ico") +# async def favicon(): +# return FileResponse(os.path.join(static_files_path, "favicon.ico")) -# /api/config endpoints -app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX) -# /v1/assistants endpoints -app.include_router(setup_openai_assistant_router(server, interface), prefix=OPENAI_API_PREFIX) +# # Middleware to handle API routes first +# @app.middleware("http") +# async def handle_api_routes(request: Request, call_next): +# if request.url.path.startswith(("/v1/", "/openai/")): +# response = await call_next(request) +# if response.status_code != 404: +# return response +# return await serve_spa(request.url.path) + + +# # Catch-all route for SPA +# async def serve_spa(full_path: str): +# return FileResponse(os.path.join(static_files_path, "index.html")) -# /v1/chat/completions endpoints -app.include_router(setup_openai_chat_completions_router(server, interface, password), prefix=OPENAI_API_PREFIX) -# / static files mount_static_files(app) diff --git a/memgpt/server/rest_api/sources/__init__.py b/memgpt/server/rest_api/sources/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/memgpt/server/rest_api/sources/index.py b/memgpt/server/rest_api/sources/index.py deleted file mode 100644 index 2eb7b3c1..00000000 --- a/memgpt/server/rest_api/sources/index.py +++ /dev/null @@ -1,262 +0,0 @@ -import os -import tempfile -from functools import partial -from typing import List - -from fastapi import ( - 
diff --git a/memgpt/server/rest_api/sources/__init__.py b/memgpt/server/rest_api/sources/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/memgpt/server/rest_api/sources/index.py b/memgpt/server/rest_api/sources/index.py
deleted file mode 100644
index 2eb7b3c1..00000000
--- a/memgpt/server/rest_api/sources/index.py
+++ /dev/null
@@ -1,262 +0,0 @@
-import os
-import tempfile
-from functools import partial
-from typing import List
-
-from fastapi import (
-    APIRouter,
-    BackgroundTasks,
-    Body,
-    Depends,
-    HTTPException,
-    Query,
-    UploadFile,
-)
-
-from memgpt.schemas.document import Document
-from memgpt.schemas.job import Job
-from memgpt.schemas.passage import Passage
-
-# schemas
-from memgpt.schemas.source import Source, SourceCreate, SourceUpdate, UploadFile
-from memgpt.server.rest_api.auth_token import get_current_user
-from memgpt.server.rest_api.interface import QueuingInterface
-from memgpt.server.server import SyncServer
-
-router = APIRouter()
-
-"""
-Implement the following functions:
-* List all available sources
-* Create a new source
-* Delete a source
-* Upload a file to a server that is loaded into a specific source
-* Paginated get all passages from a source
-* Paginated get all documents from a source
-* Attach a source to an agent
-"""
-
-
-# class ListSourcesResponse(BaseModel):
-#     sources: List[SourceModel] = Field(..., description="List of available sources.")
-#
-#
-# class CreateSourceRequest(BaseModel):
-#     name: str = Field(..., description="The name of the source.")
-#     description: Optional[str] = Field(None, description="The description of the source.")
-#
-#
-# class UploadFileToSourceRequest(BaseModel):
-#     file: UploadFile = Field(..., description="The file to upload.")
-#
-#
-# class UploadFileToSourceResponse(BaseModel):
-#     source: SourceModel = Field(..., description="The source the file was uploaded to.")
-#     added_passages: int = Field(..., description="The number of passages added to the source.")
-#     added_documents: int = Field(..., description="The number of documents added to the source.")
-#
-#
-# class GetSourcePassagesResponse(BaseModel):
-#     passages: List[PassageModel] = Field(..., description="List of passages from the source.")
-#
-#
-# class GetSourceDocumentsResponse(BaseModel):
-#     documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
-
-
-def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
-    # write the file to a temporary directory (deleted after the context manager exits)
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        file_path = os.path.join(tmpdirname, file.filename)
-        with open(file_path, "wb") as buffer:
-            buffer.write(bytes)
-
-        server.load_file_to_source(source_id, file_path, job_id)
-
-
-def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
-    get_current_user_with_server = partial(partial(get_current_user, server), password)
-
-    @router.get("/sources/{source_id}", tags=["sources"], response_model=Source)
-    async def get_source(
-        source_id: str,
-        user_id: str = Depends(get_current_user_with_server),
-    ):
-        """
-        Get all sources
-        """
-        interface.clear()
-        source = server.get_source(source_id=source_id, user_id=user_id)
-        return source
-
-    @router.get("/sources/name/{source_name}", tags=["sources"], response_model=str)
-    async def get_source_id_by_name(
-        source_name: str,
-        user_id: str = Depends(get_current_user_with_server),
-    ):
-        """
-        Get a source by name
-        """
-        interface.clear()
-        source = server.get_source_id(source_name=source_name, user_id=user_id)
-        return source
-
-    @router.get("/sources", tags=["sources"], response_model=List[Source])
-    async def list_sources(
-        user_id: str = Depends(get_current_user_with_server),
-    ):
-        """
-        List all data sources created by a user.
- """ - # Clear the interface - interface.clear() - - try: - sources = server.list_all_sources(user_id=user_id) - return sources - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") - - @router.post("/sources", tags=["sources"], response_model=Source) - async def create_source( - request: SourceCreate = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Create a new data source. - """ - interface.clear() - try: - return server.create_source(request=request, user_id=user_id) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") - - @router.post("/sources/{source_id}", tags=["sources"], response_model=Source) - async def update_source( - source_id: str, - request: SourceUpdate = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Update the name or documentation of an existing data source. - """ - interface.clear() - try: - return server.update_source(request=request, user_id=user_id) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") - - @router.delete("/sources/{source_id}", tags=["sources"]) - async def delete_source( - source_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Delete a data source. - """ - interface.clear() - try: - server.delete_source(source_id=source_id, user_id=user_id) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") - - @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=Source) - async def attach_source_to_agent( - source_id: str, - agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Attach a data source to an existing agent. - """ - interface.clear() - assert isinstance(agent_id, str), f"Expected agent_id to be a UUID, got {agent_id}" - assert isinstance(user_id, str), f"Expected user_id to be a UUID, got {user_id}" - source = server.ms.get_source(source_id=source_id, user_id=user_id) - source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=user_id) - return source - - @router.post("/sources/{source_id}/detach", tags=["sources"]) - async def detach_source_from_agent( - source_id: str, - agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Detach a data source from an existing agent. - """ - server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id) - - @router.get("/sources/status/{job_id}", tags=["sources"], response_model=Job) - async def get_job( - job_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Get the status of a job. - """ - job = server.get_job(job_id=job_id) - if job is None: - raise HTTPException(status_code=404, detail=f"Job with id={job_id} not found.") - return job - - @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=Job) - async def upload_file_to_source( - # file: UploadFile = UploadFile(..., description="The file to upload."), - file: UploadFile, - source_id: str, - background_tasks: BackgroundTasks, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Upload a file to a data source. 
- """ - interface.clear() - source = server.ms.get_source(source_id=source_id, user_id=user_id) - bytes = file.file.read() - - # create job - # TODO: create server function - job = Job(user_id=user_id, metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id}) - job_id = job.id - server.ms.create_job(job) - - # create background task - background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes) - - # return job information - job = server.ms.get_job(job_id=job_id) - return job - - @router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=List[Passage]) - async def list_passages( - source_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - List all passages associated with a data source. - """ - # TODO: check if paginated? - passages = server.list_data_source_passages(user_id=user_id, source_id=source_id) - return passages - - @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=List[Document]) - async def list_documents( - source_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - List all documents associated with a data source. - """ - documents = server.list_data_source_documents(user_id=user_id, source_id=source_id) - return documents - - return router diff --git a/memgpt/server/rest_api/static_files.py b/memgpt/server/rest_api/static_files.py index 494f6a52..b6899fca 100644 --- a/memgpt/server/rest_api/static_files.py +++ b/memgpt/server/rest_api/static_files.py @@ -21,7 +21,8 @@ def mount_static_files(app: FastAPI): static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files") if os.path.exists(static_files_path): app.mount( - "/", + # "/", + "/app", SPAStaticFiles( directory=static_files_path, html=True, diff --git a/memgpt/server/rest_api/tools/__init__.py b/memgpt/server/rest_api/tools/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/memgpt/server/rest_api/tools/index.py b/memgpt/server/rest_api/tools/index.py deleted file mode 100644 index 6072b133..00000000 --- a/memgpt/server/rest_api/tools/index.py +++ /dev/null @@ -1,98 +0,0 @@ -from functools import partial -from typing import List - -from fastapi import APIRouter, Body, Depends, HTTPException - -from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate -from memgpt.server.rest_api.auth_token import get_current_user -from memgpt.server.rest_api.interface import QueuingInterface -from memgpt.server.server import SyncServer - -router = APIRouter() - - -def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str): - get_current_user_with_server = partial(partial(get_current_user, server), password) - - @router.delete("/tools/{tool_id}", tags=["tools"]) - async def delete_tool( - tool_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Delete a tool by name - """ - # Clear the interface - interface.clear() - server.delete_tool(tool_id) - - @router.get("/tools/{tool_id}", tags=["tools"], response_model=Tool) - async def get_tool( - tool_id: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Get a tool by name - """ - # Clear the interface - interface.clear() - tool = server.get_tool(tool_id) - if tool is None: - # return 404 error - raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.") - return tool - - @router.get("/tools/name/{tool_name}", 
tags=["tools"], response_model=str) - async def get_tool_id( - tool_name: str, - user_id: str = Depends(get_current_user_with_server), - ): - """ - Get a tool by name - """ - # Clear the interface - interface.clear() - tool = server.get_tool_id(tool_name, user_id=user_id) - if tool is None: - # return 404 error - raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.") - return tool - - @router.get("/tools", tags=["tools"], response_model=List[Tool]) - async def list_all_tools( - user_id: str = Depends(get_current_user_with_server), - ): - """ - Get a list of all tools available to agents created by a user - """ - # Clear the interface - interface.clear() - return server.list_tools(user_id) - - @router.post("/tools", tags=["tools"], response_model=Tool) - async def create_tool( - request: ToolCreate = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Create a new tool - """ - return server.create_tool(request, user_id=user_id) - - @router.post("/tools/{tool_id}", tags=["tools"], response_model=Tool) - async def update_tool( - tool_id: str, - request: ToolUpdate = Body(...), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Update an existing tool - """ - try: - # TODO: check that the user has access to this tool? - return server.update_tool(request) - except Exception as e: - print(e) - raise HTTPException(status_code=500, detail=f"Failed to update tool: {e}") - - return router diff --git a/memgpt/server/rest_api/utils.py b/memgpt/server/rest_api/utils.py index 5ad4228b..fdbccdc7 100644 --- a/memgpt/server/rest_api/utils.py +++ b/memgpt/server/rest_api/utils.py @@ -5,6 +5,9 @@ from typing import AsyncGenerator, Union from pydantic import BaseModel +from memgpt.server.rest_api.interface import StreamingServerInterface +from memgpt.server.server import SyncServer + # from memgpt.orm.user import User # from memgpt.orm.utilities import get_db_session @@ -51,3 +54,13 @@ async def sse_async_generator(generator: AsyncGenerator, finish_message=True): if finish_message: # Signal that the stream is complete yield sse_formatter(SSE_FINISH_MSG) + + +# TODO: why does this double up the interface? 
+def get_memgpt_server() -> SyncServer:
+    server = SyncServer(default_interface_factory=lambda: StreamingServerInterface())
+    return server
+
+
+def get_current_interface() -> StreamingServerInterface:
+    return StreamingServerInterface()
diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index 60be88d2..4d6feffe 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -889,10 +889,10 @@ class SyncServer(Server):
         self,
         user_id: Optional[str] = None,
         label: Optional[str] = None,
-        template: Optional[bool] = None,
+        template: bool = True,
         name: Optional[str] = None,
         id: Optional[str] = None,
-    ):
+    ) -> Optional[List[Block]]:
         return self.ms.get_blocks(user_id=user_id, label=label, template=template, name=name, id=id)
@@ -1550,14 +1550,14 @@ class SyncServer(Server):
         return sources_with_metadata
-    def get_tool(self, tool_id: str) -> Tool:
+    def get_tool(self, tool_id: str) -> Optional[Tool]:
         """Get tool by ID."""
         return self.ms.get_tool(tool_id=tool_id)
-    def get_tool_id(self, name: str, user_id: str) -> str:
+    def get_tool_id(self, name: str, user_id: str) -> Optional[str]:
         """Get tool ID from name and user_id."""
         tool = self.ms.get_tool(tool_name=name, user_id=user_id)
-        if not tool:
+        if not tool or tool.id is None:
             return None
         return tool.id
@@ -1770,3 +1770,38 @@ class SyncServer(Server):
         # Get the current message
         memgpt_agent = self._get_or_load_agent(agent_id=agent_id)
         return memgpt_agent.retry_message()
+
+    # TODO(ethan) wire back to real method in future ORM PR
+    def get_current_user(self) -> User:
+        """Returns the currently authed user.
+
+        Since server is the core gateway this needs to pass through server as the
+        first touchpoint.
+        """
+        # NOTE: same code as local client to get the default user
+        config = MemGPTConfig.load()
+        user_id = config.anon_clientid
+        user = self.get_user(user_id)
+
+        if not user:
+            user = self.create_user(UserCreate())
+
+        # update config
+        config.anon_clientid = str(user.id)
+        config.save()
+
+        return user
+
+    def list_models(self) -> List[LLMConfig]:
+        """List available models"""
+
+        # TODO support multiple models
+        llm_config = self.server_llm_config
+        return [llm_config]
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        """List available embedding models"""
+
+        # TODO support multiple models
+        embedding_config = self.server_embedding_config
+        return [embedding_config]
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 8bcc0b21..4363f22f 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -2,12 +2,14 @@ import os
 import threading
 import time
 import uuid
+from typing import Union
 import pytest
 from dotenv import load_dotenv
 from memgpt import Admin, create_client
 from memgpt.agent import Agent
+from memgpt.client.client import LocalClient, RESTClient
 from memgpt.constants import DEFAULT_PRESET
 from memgpt.schemas.memory import ChatMemory
@@ -86,7 +88,7 @@ def agent(client):
     client.delete_agent(agent_state.id)
-def test_create_tool(client):
+def test_create_tool(client: Union[LocalClient, RESTClient]):
     """Test creation of a simple tool"""
     def print_tool(message: str):
@@ -111,7 +113,9 @@ def test_create_tool(client):
     print(f"Updated tools {[t.name for t in tools]}")
     # check tool id
-    tool = client.get_tool(tool.name)
+    fetched_tool = client.get_tool(tool.id)
+    assert fetched_tool is not None, "Expected tool to be created"
+    assert fetched_tool.id == tool.id, f"Expected fetched tool ID {fetched_tool.id} to match {tool.id}"
     # TODO: add back once we fix admin client tool creation
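
[Editor's note: to close, here is a hedged sketch of consuming the server-sent-event stream produced by `send_message_to_agent` and `sse_async_generator`; it is not part of the patch. The host, port, agent ID, and the exact finish sentinel are assumptions (an OpenAI-style "[DONE]" marker is assumed for SSE_FINISH_MSG, whose value is not shown in this diff).]

    # Hypothetical SSE consumer sketch -- not part of the patch.
    import json
    import requests

    URL = "http://localhost:8283/v1/agents/agent-1234/messages"  # host/port/agent ID are placeholders

    with requests.post(
        URL,
        json={
            "messages": [{"role": "user", "text": "Hi there"}],
            "stream_steps": True,
            "stream_tokens": False,
            "return_message_object": False,
        },
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip blank keep-alive separators between events
            payload = line[len("data: "):]
            if payload == "[DONE]":  # assumed finish sentinel from sse_async_generator
                break
            print(json.loads(payload))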