feat: add V1 route refactor from integration branch into separate PR (#1734)

This commit is contained in:
Charles Packer
2024-09-09 20:49:59 -07:00
committed by GitHub
parent cb20fbfab9
commit 635fb1cc66
41 changed files with 2053 additions and 1830 deletions

View File

@@ -1,38 +1,36 @@
name: Code Formatter (Black)
on:
pull_request:
paths:
- '**.py'
workflow_dispatch:
jobs:
black-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }} # Checkout the PR branch
fetch-depth: 0 # Fetch all history for all branches and tags
- name: "Setup Python, Poetry and Dependencies"
uses: packetcoders/action-setup-cache-python-poetry@main
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev" # TODO: change this to --group dev when PR #842 lands
- name: Run Black
id: black
run: poetry run black --check .
continue-on-error: true
- name: Auto-fix with Black and commit
if: steps.black.outcome == 'failure' || ${{ !contains(github.event.issue.labels.*.name, 'check-only') }}
if: steps.black.outcome == 'failure' && !contains(github.event.pull_request.labels.*.name, 'check-only')
run: |
poetry run black .
git config --local user.email "action@github.com"
git config --local user.name "GitHub Action"
git commit -am "Apply Black formatting" || echo "No changes to commit"
git diff --quiet && git diff --staged --quiet || (git add -A && git commit -m "Apply Black formatting")
git push
- name: Error if 'check-only' label is present
if: steps.black.outcome == 'failure' && contains(github.event.issue.labels.*.name, 'check-only')
if: steps.black.outcome == 'failure' && contains(github.event.pull_request.labels.*.name, 'check-only')
run: echo "Black formatting check failed. Please run 'black .' locally to fix formatting issues." && exit 1

View File

@@ -1,39 +1,48 @@
name: Code Formatter (isort)
on:
pull_request:
paths:
- '**.py'
workflow_dispatch:
jobs:
isort-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }} # Checkout the PR branch
fetch-depth: 0 # Fetch all history for all branches and tags
- name: "Setup Python, Poetry and Dependencies"
uses: packetcoders/action-setup-cache-python-poetry@main
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev" # TODO: change this to --group dev when PR #842 lands
- name: Run isort
id: isort
run: poetry run isort --profile black --check-only .
run: |
output=$(poetry run isort --profile black --check-only --diff . | grep -v "Skipped" || true)
echo "$output"
if [ -n "$output" ]; then
echo "isort_changed=true" >> $GITHUB_OUTPUT
else
echo "isort_changed=false" >> $GITHUB_OUTPUT
fi
continue-on-error: true
- name: Auto-fix with isort and commit
if: steps.isort.outcome == 'failure' || ${{ !contains(github.event.issue.labels.*.name, 'check-only') }}
if: steps.isort.outputs.isort_changed == 'true' && !contains(github.event.pull_request.labels.*.name, 'check-only')
run: |
poetry run isort --profile black .
git config --local user.email "action@github.com"
git config --local user.name "GitHub Action"
git commit -am "Apply isort import ordering" || echo "No changes to commit"
git push
- name: Error if 'check-only' label is present
if: steps.isort.outcome == 'failure' && contains(github.event.issue.labels.*.name, 'check-only')
if [[ -n $(git status -s) ]]; then
git add -A
git commit -m "Apply isort import ordering"
git push
else
echo "No changes to commit"
fi
- name: Error if 'check-only' label is present and changes are needed
if: steps.isort.outputs.isort_changed == 'true' && contains(github.event.pull_request.labels.*.name, 'check-only')
run: echo "Isort check failed. Please run 'isort .' locally to fix import orders." && exit 1

View File

@@ -16,8 +16,14 @@ class Admin:
- Generating user keys
"""
def __init__(self, base_url: str, token: str):
def __init__(
self,
base_url: str,
token: str,
api_prefix: str = "v1",
):
self.base_url = base_url
self.api_prefix = api_prefix
self.token = token
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
@@ -27,35 +33,36 @@ class Admin:
params["cursor"] = str(cursor)
if limit:
params["limit"] = limit
response = requests.get(f"{self.base_url}/admin/users", params=params, headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/users", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return [User(**user) for user in response.json()]
def create_key(self, user_id: str, key_name: Optional[str] = None) -> APIKey:
request = APIKeyCreate(user_id=user_id, name=key_name)
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=request.model_dump())
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users/keys", headers=self.headers, json=request.model_dump())
if response.status_code != 200:
raise HTTPError(response.json())
return APIKey(**response.json())
def get_keys(self, user_id: str) -> List[APIKey]:
params = {"user_id": str(user_id)}
response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return [APIKey(**key) for key in response.json()]
def delete_key(self, api_key: str) -> APIKey:
params = {"api_key": api_key}
response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return APIKey(**response.json())
def create_user(self, name: Optional[str] = None) -> User:
request = UserCreate(name=name)
response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=request.model_dump())
print("YYYYY Pinging", f"{self.base_url}/{self.api_prefix}/admin/users")
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump())
if response.status_code != 200:
raise HTTPError(response.json())
response_json = response.json()
@@ -63,7 +70,7 @@ class Admin:
def delete_user(self, user_id: str) -> User:
params = {"user_id": str(user_id)}
response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers)
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/users", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return User(**response.json())
@@ -114,23 +121,23 @@ class Admin:
CreateToolRequest(**data) # validate
# make REST request
response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/tools", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
return ToolModel(**response.json())
def list_tools(self):
response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/tools", headers=self.headers)
return ListToolsResponse(**response.json()).tools
def delete_tool(self, name: str):
response = requests.delete(f"{self.base_url}/admin/tools/{name}", headers=self.headers)
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/tools/{name}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to delete tool: {response.text}")
return response.json()
def get_tool(self, name: str):
response = requests.get(f"{self.base_url}/admin/tools/{name}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/tools/{name}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:

View File

@@ -1,6 +1,6 @@
import logging
import time
from typing import Dict, Generator, List, Optional, Union
from typing import Callable, Dict, Generator, List, Optional, Union
import requests
@@ -185,7 +185,7 @@ class AbstractClient(object):
self,
id: str,
name: Optional[str] = None,
func: Optional[callable] = None,
func: Optional[Callable] = None,
tags: Optional[List[str]] = None,
) -> Tool:
raise NotImplementedError
@@ -271,6 +271,7 @@ class RESTClient(AbstractClient):
self,
base_url: str,
token: str,
api_prefix: str = "v1",
debug: bool = False,
):
"""
@@ -283,10 +284,11 @@ class RESTClient(AbstractClient):
"""
super().__init__(debug=debug)
self.base_url = base_url
self.api_prefix = api_prefix
self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
def list_agents(self) -> List[AgentState]:
response = requests.get(f"{self.base_url}/api/agents", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers)
return [AgentState(**agent) for agent in response.json()]
def agent_exists(self, agent_id: str) -> bool:
@@ -301,7 +303,7 @@ class RESTClient(AbstractClient):
exists (bool): `True` if the agent exists, `False` otherwise
"""
response = requests.get(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", headers=self.headers)
if response.status_code == 404:
# not found error
return False
@@ -375,7 +377,7 @@ class RESTClient(AbstractClient):
embedding_config=embedding_config,
)
response = requests.post(f"{self.base_url}/api/agents", json=request.model_dump(), headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}")
return AgentState(**response.json())
@@ -399,7 +401,7 @@ class RESTClient(AbstractClient):
tool_call_id=tool_call_id,
)
response = requests.patch(
f"{self.base_url}/api/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
)
if response.status_code != 200:
raise ValueError(f"Failed to update message: {response.text}")
@@ -448,7 +450,7 @@ class RESTClient(AbstractClient):
message_ids=message_ids,
memory=memory,
)
response = requests.post(f"{self.base_url}/api/agents/{agent_id}", json=request.model_dump(), headers=self.headers)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update agent: {response.text}")
return AgentState(**response.json())
@@ -471,7 +473,7 @@ class RESTClient(AbstractClient):
Args:
agent_id (str): ID of the agent to delete
"""
response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers)
response = requests.delete(f"{self.base_url}/{self.api_prefix}/agents/{str(agent_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete agent: {response.text}"
def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> AgentState:
@@ -484,7 +486,7 @@ class RESTClient(AbstractClient):
Returns:
agent_state (AgentState): State representation of the agent
"""
response = requests.get(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", headers=self.headers)
assert response.status_code == 200, f"Failed to get agent: {response.text}"
return AgentState(**response.json())
@@ -512,7 +514,7 @@ class RESTClient(AbstractClient):
Returns:
memory (Memory): In-context memory of the agent
"""
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get in-context memory: {response.text}")
return Memory(**response.json())
@@ -529,7 +531,9 @@ class RESTClient(AbstractClient):
"""
memory_update_dict = {section: value}
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/memory", json=memory_update_dict, headers=self.headers)
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory", json=memory_update_dict, headers=self.headers
)
if response.status_code != 200:
raise ValueError(f"Failed to update in-context memory: {response.text}")
return Memory(**response.json())
@@ -545,7 +549,7 @@ class RESTClient(AbstractClient):
summary (ArchivalMemorySummary): Summary of the archival memory
"""
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/archival", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/archival", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get archival memory summary: {response.text}")
return ArchivalMemorySummary(**response.json())
@@ -560,7 +564,7 @@ class RESTClient(AbstractClient):
Returns:
summary (RecallMemorySummary): Summary of the recall memory
"""
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/recall", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/recall", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get recall memory summary: {response.text}")
return RecallMemorySummary(**response.json())
@@ -575,7 +579,7 @@ class RESTClient(AbstractClient):
Returns:
messages (List[Message]): List of in-context messages
"""
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/messages", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/messages", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get in-context messages: {response.text}")
return [Message(**message) for message in response.json()]
@@ -620,7 +624,7 @@ class RESTClient(AbstractClient):
params["before"] = str(before)
if after:
params["after"] = str(after)
response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/archival", params=params, headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{str(agent_id)}/archival", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to get archival memory: {response.text}"
return [Passage(**passage) for passage in response.json()]
@@ -636,7 +640,9 @@ class RESTClient(AbstractClient):
passages (List[Passage]): List of inserted passages
"""
request = CreateArchivalMemory(text=memory)
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/archival", headers=self.headers, json=request.model_dump())
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/archival", headers=self.headers, json=request.model_dump()
)
if response.status_code != 200:
raise ValueError(f"Failed to insert archival memory: {response.text}")
return [Passage(**passage) for passage in response.json()]
@@ -649,7 +655,7 @@ class RESTClient(AbstractClient):
agent_id (str): ID of the agent
memory_id (str): ID of the memory
"""
response = requests.delete(f"{self.base_url}/api/agents/{agent_id}/archival/{memory_id}", headers=self.headers)
response = requests.delete(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/archival/{memory_id}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete archival memory: {response.text}"
# messages (recall memory)
@@ -671,7 +677,7 @@ class RESTClient(AbstractClient):
"""
params = {"before": before, "after": after, "limit": limit, "msg_object": True}
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages", params=params, headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", params=params, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get messages: {response.text}")
return [Message(**message) for message in response.json()]
@@ -708,9 +714,11 @@ class RESTClient(AbstractClient):
from memgpt.client.streaming import _sse_post
request.return_message_object = False
return _sse_post(f"{self.base_url}/api/agents/{agent_id}/messages", request.model_dump(), self.headers)
return _sse_post(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", request.model_dump(), self.headers)
else:
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers)
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers
)
if response.status_code != 200:
raise ValueError(f"Failed to send message: {response.text}")
return MemGPTResponse(**response.json())
@@ -719,7 +727,7 @@ class RESTClient(AbstractClient):
def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
params = {"label": label, "templates_only": templates_only}
response = requests.get(f"{self.base_url}/api/blocks", params=params, headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks", params=params, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list blocks: {response.text}")
@@ -732,7 +740,7 @@ class RESTClient(AbstractClient):
def create_block(self, label: str, name: str, text: str) -> Block: #
request = CreateBlock(label=label, name=name, value=text)
response = requests.post(f"{self.base_url}/api/blocks", json=request.model_dump(), headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create block: {response.text}")
if request.label == "human":
@@ -744,13 +752,13 @@ class RESTClient(AbstractClient):
def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block:
request = UpdateBlock(id=block_id, name=name, value=text)
response = requests.post(f"{self.base_url}/api/blocks/{block_id}", json=request.model_dump(), headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update block: {response.text}")
return Block(**response.json())
def get_block(self, block_id: str) -> Block:
response = requests.get(f"{self.base_url}/api/blocks/{block_id}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
@@ -759,7 +767,7 @@ class RESTClient(AbstractClient):
def get_block_id(self, name: str, label: str) -> str:
params = {"name": name, "label": label}
response = requests.get(f"{self.base_url}/api/blocks", params=params, headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks", params=params, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get block ID: {response.text}")
blocks = [Block(**block) for block in response.json()]
@@ -770,7 +778,7 @@ class RESTClient(AbstractClient):
return blocks[0].id
def delete_block(self, id: str) -> Block:
response = requests.delete(f"{self.base_url}/api/blocks/{id}", headers=self.headers)
response = requests.delete(f"{self.base_url}/{self.api_prefix}/blocks/{id}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete block: {response.text}"
if response.status_code != 200:
raise ValueError(f"Failed to delete block: {response.text}")
@@ -811,7 +819,7 @@ class RESTClient(AbstractClient):
human (Human): Updated human block
"""
request = UpdateHuman(id=human_id, name=name, value=text)
response = requests.post(f"{self.base_url}/api/blocks/{human_id}", json=request.model_dump(), headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{human_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update human: {response.text}")
return Human(**response.json())
@@ -851,7 +859,7 @@ class RESTClient(AbstractClient):
persona (Persona): Updated persona block
"""
request = UpdatePersona(id=persona_id, name=name, value=text)
response = requests.post(f"{self.base_url}/api/blocks/{persona_id}", json=request.model_dump(), headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{persona_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update persona: {response.text}")
return Persona(**response.json())
@@ -934,7 +942,7 @@ class RESTClient(AbstractClient):
Returns:
source (Source): Source
"""
response = requests.get(f"{self.base_url}/api/sources/{source_id}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get source: {response.text}")
return Source(**response.json())
@@ -949,7 +957,7 @@ class RESTClient(AbstractClient):
Returns:
source_id (str): ID of the source
"""
response = requests.get(f"{self.base_url}/api/sources/name/{source_name}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/name/{source_name}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get source ID: {response.text}")
return response.json()
@@ -961,7 +969,7 @@ class RESTClient(AbstractClient):
Returns:
sources (List[Source]): List of sources
"""
response = requests.get(f"{self.base_url}/api/sources", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/sources", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list sources: {response.text}")
return [Source(**source) for source in response.json()]
@@ -973,21 +981,21 @@ class RESTClient(AbstractClient):
Args:
source_id (str): ID of the source
"""
response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers)
response = requests.delete(f"{self.base_url}/{self.api_prefix}/sources/{str(source_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete source: {response.text}"
def get_job(self, job_id: str) -> Job:
response = requests.get(f"{self.base_url}/api/jobs/{job_id}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get job: {response.text}")
return Job(**response.json())
def list_jobs(self):
response = requests.get(f"{self.base_url}/api/jobs", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers)
return [Job(**job) for job in response.json()]
def list_active_jobs(self):
response = requests.get(f"{self.base_url}/api/jobs/active", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs/active", headers=self.headers)
return [Job(**job) for job in response.json()]
def load_data(self, connector: DataConnector, source_name: str):
@@ -1008,7 +1016,7 @@ class RESTClient(AbstractClient):
files = {"file": open(filename, "rb")}
# create job
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/upload", files=files, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to upload file to source: {response.text}")
@@ -1035,7 +1043,7 @@ class RESTClient(AbstractClient):
source (Source): Created source
"""
payload = {"name": name}
response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers)
response_json = response.json()
return Source(**response_json)
@@ -1049,7 +1057,7 @@ class RESTClient(AbstractClient):
Returns:
sources (List[Source]): List of sources
"""
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/sources", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/sources", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list attached sources: {response.text}")
return [Source(**source) for source in response.json()]
@@ -1066,7 +1074,7 @@ class RESTClient(AbstractClient):
source (Source): Updated source
"""
request = SourceUpdate(id=source_id, name=name)
response = requests.post(f"{self.base_url}/api/sources/{source_id}", json=request.model_dump(), headers=self.headers)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update source: {response.text}")
return Source(**response.json())
@@ -1081,13 +1089,13 @@ class RESTClient(AbstractClient):
source_name (str): Name of the source
"""
params = {"agent_id": agent_id}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/attach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
def detach_source(self, source_id: str, agent_id: str):
"""Detach a source from an agent"""
params = {"agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/detach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
# server configuration commands
@@ -1099,7 +1107,7 @@ class RESTClient(AbstractClient):
Returns:
models (List[LLMConfig]): List of LLM models
"""
response = requests.get(f"{self.base_url}/api/config/llm", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list models: {response.text}")
return [LLMConfig(**model) for model in response.json()]
@@ -1111,7 +1119,7 @@ class RESTClient(AbstractClient):
Returns:
models (List[EmbeddingConfig]): List of embedding models
"""
response = requests.get(f"{self.base_url}/api/config/embedding", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list embedding models: {response.text}")
return [EmbeddingConfig(**model) for model in response.json()]
@@ -1128,7 +1136,7 @@ class RESTClient(AbstractClient):
Returns:
id (str): ID of the tool (`None` if not found)
"""
response = requests.get(f"{self.base_url}/api/tools/name/{tool_name}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{tool_name}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
@@ -1137,7 +1145,7 @@ class RESTClient(AbstractClient):
def create_tool(
self,
func,
func: Callable,
name: Optional[str] = None,
update: Optional[bool] = True, # TODO: actually use this
tags: Optional[List[str]] = None,
@@ -1157,14 +1165,21 @@ class RESTClient(AbstractClient):
# TODO: check tool update code
# TODO: check if tool already exists
# TODO: how to load modules?
# parse source code/schema
source_code = parse_source_code(func)
source_type = "python"
# TODO: Check if tool already exists
# if name:
# tool_id = self.get_tool_id(tool_name=name)
# if tool_id:
# raise ValueError(f"Tool with name {name} (id={tool_id}) already exists")
# call server function
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
response = requests.post(f"{self.base_url}/api/tools", json=request.model_dump(), headers=self.headers)
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
return Tool(**response.json())
@@ -1173,7 +1188,7 @@ class RESTClient(AbstractClient):
self,
id: str,
name: Optional[str] = None,
func: Optional[callable] = None,
func: Optional[Callable] = None,
tags: Optional[List[str]] = None,
) -> Tool:
"""
@@ -1196,7 +1211,7 @@ class RESTClient(AbstractClient):
source_type = "python"
request = ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name)
response = requests.post(f"{self.base_url}/api/tools/{id}", json=request.model_dump(), headers=self.headers)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update tool: {response.text}")
return Tool(**response.json())
@@ -1235,7 +1250,7 @@ class RESTClient(AbstractClient):
# raise ValueError(f"Failed to create tool: {e}, invalid input {data}")
# # make REST request
# response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
# response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=data, headers=self.headers)
# if response.status_code != 200:
# raise ValueError(f"Failed to create tool: {response.text}")
# return ToolModel(**response.json())
@@ -1247,7 +1262,7 @@ class RESTClient(AbstractClient):
Returns:
tools (List[Tool]): List of tools
"""
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list tools: {response.text}")
return [Tool(**tool) for tool in response.json()]
@@ -1259,11 +1274,11 @@ class RESTClient(AbstractClient):
Args:
id (str): ID of the tool
"""
response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers)
response = requests.delete(f"{self.base_url}/{self.api_prefix}/tools/{name}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to delete tool: {response.text}")
def get_tool(self, name: str):
def get_tool(self, id: str) -> Optional[Tool]:
"""
Get a tool give its ID.
@@ -1273,13 +1288,30 @@ class RESTClient(AbstractClient):
Returns:
tool (Tool): Tool
"""
response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers)
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/{id}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
raise ValueError(f"Failed to get tool: {response.text}")
return Tool(**response.json())
def get_tool_id(self, name: str) -> Optional[str]:
"""
Get a tool ID by its name.
Args:
id (str): ID of the tool
Returns:
tool (Tool): Tool
"""
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{name}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
raise ValueError(f"Failed to get tool: {response.text}")
return response.json()
class LocalClient(AbstractClient):
"""
@@ -1972,7 +2004,7 @@ class LocalClient(AbstractClient):
tools = self.server.list_tools(user_id=self.user_id)
return tools
def get_tool(self, id: str) -> Tool:
def get_tool(self, id: str) -> Optional[Tool]:
"""
Get a tool give its ID.
@@ -2222,7 +2254,7 @@ class LocalClient(AbstractClient):
Returns:
models (List[LLMConfig]): List of LLM models
"""
return [self.server.server_llm_config]
return self.server.list_models()
def list_embedding_models(self) -> List[EmbeddingConfig]:
"""
@@ -2231,7 +2263,7 @@ class LocalClient(AbstractClient):
Returns:
models (List[EmbeddingConfig]): List of embedding models
"""
return [self.server.server_embedding_config]
return self.server.list_embedding_models()
def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
"""

View File

@@ -744,7 +744,7 @@ class MetadataStore:
template: Optional[bool] = None,
name: Optional[str] = None,
id: Optional[str] = None,
) -> List[Block]:
) -> Optional[List[Block]]:
"""List available blocks"""
with self.session_maker() as session:
query = session.query(BlockModel)

View File

@@ -1,106 +0,0 @@
from functools import partial
from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from memgpt.schemas.source import Source
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the /agents CRUD endpoints on the module-level router.

    Args:
        server: Backend server that performs the actual agent operations.
        interface: Per-request queuing interface; cleared before each handler
            runs so stale output from a previous request is not returned.
        password: Server password used when resolving the requesting user.

    Returns:
        The configured APIRouter.
    """
    # Dependency that resolves the current user from the request, with the
    # server and password pre-bound.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/agents", tags=["agents"], response_model=List[AgentState])
    def list_agents(
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        List all agents associated with a given user.

        This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
        """
        interface.clear()
        return server.list_agents(user_id=user_id)

    @router.post("/agents", tags=["agents"], response_model=AgentState)
    def create_agent(
        request: CreateAgent = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Create a new agent with the specified configuration.
        """
        interface.clear()
        return server.create_agent(request, user_id=user_id)

    @router.post("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
    def update_agent(
        agent_id: str,
        request: UpdateAgentState = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """Update an existing agent"""
        interface.clear()
        try:
            # TODO: should id be moved out of UpdateAgentState?
            agent_state = server.update_agent(request, user_id=user_id)
        except Exception as e:
            print(str(e))
            raise HTTPException(status_code=500, detail=str(e))
        return agent_state

    @router.get("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
    def get_agent_state(
        # NOTE: path parameters are always required in FastAPI; the previous
        # `= None` default was dead and misleading, so it was removed.
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get the state of the agent.
        """
        interface.clear()
        if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
            # agent does not exist
            raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
        return server.get_agent_state(user_id=user_id, agent_id=agent_id)

    @router.delete("/agents/{agent_id}", tags=["agents"])
    def delete_agent(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Delete an agent.
        """
        interface.clear()
        try:
            server.delete_agent(user_id=user_id, agent_id=agent_id)
        except HTTPException:
            # preserve deliberate HTTP errors (e.g. 404) raised by the server layer
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.get("/agents/{agent_id}/sources", tags=["agents"], response_model=List[Source])
    def get_agent_sources(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get the sources associated with an agent.
        """
        interface.clear()
        return server.list_attached_sources(agent_id)

    return router

View File

@@ -1,148 +0,0 @@
from functools import partial
from typing import Dict, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
from memgpt.schemas.memory import (
ArchivalMemorySummary,
CreateArchivalMemory,
Memory,
RecallMemorySummary,
)
from memgpt.schemas.message import Message
from memgpt.schemas.passage import Passage
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the agent memory endpoints (core, recall, archival) on the module-level router.

    Args:
        server: Backend server that performs the actual memory operations.
        interface: Per-request queuing interface; cleared before each handler runs.
        password: Server password used when resolving the requesting user.

    Returns:
        The configured APIRouter.
    """
    # Dependency resolving the current user, with server and password pre-bound.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/agents/{agent_id}/memory/messages", tags=["agents"], response_model=List[Message])
    def get_agent_in_context_messages(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the messages in the context of a specific agent.
        """
        interface.clear()
        return server.get_in_context_messages(agent_id=agent_id)

    @router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
    def get_agent_memory(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the memory state of a specific agent.
        This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
        """
        interface.clear()
        return server.get_agent_memory(agent_id=agent_id)

    @router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
    def update_agent_memory(
        agent_id: str,
        request: Dict = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Update the core memory of a specific agent.
        This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
        """
        interface.clear()
        memory = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=request)
        return memory

    @router.get("/agents/{agent_id}/memory/recall", tags=["agents"], response_model=RecallMemorySummary)
    def get_agent_recall_memory_summary(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the summary of the recall memory of a specific agent.
        """
        interface.clear()
        return server.get_recall_memory_summary(agent_id=agent_id)

    @router.get("/agents/{agent_id}/memory/archival", tags=["agents"], response_model=ArchivalMemorySummary)
    def get_agent_archival_memory_summary(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the summary of the archival memory of a specific agent.
        """
        interface.clear()
        return server.get_archival_memory_summary(agent_id=agent_id)

    # @router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=List[Passage])
    # def get_agent_archival_memory_all(
    #     agent_id: str,
    #     user_id: str = Depends(get_current_user_with_server),
    # ):
    #     """
    #     Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
    #     """
    #     interface.clear()
    #     return server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)

    @router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage])
    def get_agent_archival_memory(
        agent_id: str,
        after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
        before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
        limit: Optional[int] = Query(None, description="How many results to include in the response."),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the memories in an agent's archival memory store (paginated query).
        """
        interface.clear()
        return server.get_agent_archival_cursor(
            user_id=user_id,
            agent_id=agent_id,
            after=after,
            before=before,
            limit=limit,
        )

    @router.post("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage])
    def insert_agent_archival_memory(
        agent_id: str,
        request: CreateArchivalMemory = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Insert a memory into an agent's archival memory store.
        """
        interface.clear()
        return server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=request.text)

    @router.delete("/agents/{agent_id}/archival/{memory_id}", tags=["agents"])
    def delete_agent_archival_memory(
        agent_id: str,
        memory_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Delete a memory from an agent's archival memory store.
        """
        # TODO: should probably return a `Passage`
        interface.clear()
        try:
            server.delete_archival_memory(user_id=user_id, agent_id=agent_id, memory_id=memory_id)
            return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
        except HTTPException:
            # preserve deliberate HTTP errors raised by the server layer
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    return router

View File

@@ -1,197 +0,0 @@
import asyncio
from datetime import datetime
from functools import partial
from typing import List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.message import Message, UpdateMessage
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate
router = APIRouter()
# TODO: cpacker should check this file
# TODO: move this into server.py?
async def send_message_to_agent(
    server: SyncServer,
    agent_id: str,
    user_id: str,
    role: MessageRole,
    message: str,
    stream_steps: bool,
    stream_tokens: bool,
    return_message_object: bool,  # Should be True for Python Client, False for REST API
    chat_completion_mode: Optional[bool] = False,
    timestamp: Optional[datetime] = None,
    # related to whether or not we return `MemGPTMessage`s or `Message`s
) -> Union[StreamingResponse, MemGPTResponse]:
    """Split off into a separate function so that it can be imported in the /chat/completion proxy.

    Sends a single user/system message to the agent and either streams the
    response as SSE (stream_steps=True) or buffers the generated stream and
    returns it as a MemGPTResponse.

    Raises:
        HTTPException: 500 for a bad role, 400 if stream_tokens is requested
            without stream_steps, 500 for any unexpected failure.
    """
    # TODO: @charles is this the correct way to handle?
    include_final_message = True

    # determine role
    if role == MessageRole.user:
        message_func = server.user_message
    elif role == MessageRole.system:
        message_func = server.system_message
    else:
        raise HTTPException(status_code=500, detail=f"Bad role {role}")

    # token streaming only makes sense within a step stream
    if not stream_steps and stream_tokens:
        raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")

    # For streaming response
    try:
        # TODO: move this logic into server.py

        # Get the generator object off of the agent's streaming interface
        # This will be attached to the POST SSE request used under-the-hood
        memgpt_agent = server._get_or_load_agent(agent_id=agent_id)
        streaming_interface = memgpt_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")

        # Enable token-streaming within the request if desired
        streaming_interface.streaming_mode = stream_tokens
        # "chatcompletion mode" does some remapping and ignores inner thoughts
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode

        # streaming_interface.allow_assistant_message = stream
        # streaming_interface.function_call_legacy_mode = stream

        # Offload the synchronous message_func to a separate thread
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
        )

        if stream_steps:
            if return_message_object:
                # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
                raise NotImplementedError

            # return a stream
            return StreamingResponse(
                sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
                media_type="text/event-stream",
            )

        else:
            # buffer the stream, then return the list
            generated_stream = []
            async for message in streaming_interface.get_generator():
                assert (
                    isinstance(message, MemGPTMessage)
                    or isinstance(message, LegacyMemGPTMessage)
                    or isinstance(message, MessageStreamStatus)
                ), type(message)
                generated_stream.append(message)
                if message == MessageStreamStatus.done:
                    break

            # Get rid of the stream status messages
            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
            # wait for the offloaded message_func thread to finish and report usage
            usage = await task

            # By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
            # If we want to convert these to Message, we can use the attached IDs
            # NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
            # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
            if return_message_object:
                message_ids = [m.id for m in filtered_stream]
                message_ids = deduplicate(message_ids)
                message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
                return MemGPTResponse(messages=message_objs, usage=usage)
            else:
                return MemGPTResponse(messages=filtered_stream, usage=usage)

    except HTTPException:
        raise
    except Exception as e:
        print(e)
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the agent message endpoints (history, send, update) on the module-level router.

    Args:
        server: Backend server that performs the actual message operations.
        interface: Per-request queuing interface; cleared before each handler runs.
        password: Server password used when resolving the requesting user.

    Returns:
        The configured APIRouter.
    """
    # Dependency resolving the current user, with server and password pre-bound.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message])
    def get_agent_messages(
        agent_id: str,
        before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
        limit: int = Query(10, description="Maximum number of messages to retrieve."),
        msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve message history for an agent.
        """
        interface.clear()
        # reverse=True returns messages newest-first
        return server.get_agent_recall_cursor(
            user_id=user_id,
            agent_id=agent_id,
            before=before,
            limit=limit,
            reverse=True,
            return_message_object=msg_object,
        )

    @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse)
    async def send_message(
        # background_tasks: BackgroundTasks,
        agent_id: str,
        request: MemGPTRequest = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Process a user message and return the agent's response.
        This endpoint accepts a message from a user and processes it through the agent.
        It can optionally stream the response if 'stream' is set to True.
        """
        # TODO: should this receive multiple messages? @cpacker
        # TODO: revise to `MemGPTRequest`
        # TODO: support sending multiple messages
        assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
        message = request.messages[0]

        # TODO: what to do with message.name?
        return await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=user_id,
            role=message.role,
            message=message.text,
            stream_steps=request.stream_steps,
            stream_tokens=request.stream_tokens,
            return_message_object=request.return_message_object,
        )

    @router.patch("/agents/{agent_id}/messages/{message_id}", tags=["agents"], response_model=Message)
    async def update_message(
        agent_id: str,
        message_id: str,
        request: UpdateMessage = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Update the details of a message associated with an agent.
        """
        # the body's ID must agree with the path parameter
        assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}"
        return server.update_agent_message(agent_id=agent_id, request=request)

    return router

View File

@@ -1,73 +0,0 @@
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from memgpt.schemas.block import Block, CreateBlock
from memgpt.schemas.block import Human as HumanModel # TODO: modify
from memgpt.schemas.block import UpdateBlock
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
def setup_block_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the /blocks CRUD endpoints on the module-level router.

    Args:
        server: Backend server that performs the actual block operations.
        interface: Per-request queuing interface; cleared before each handler runs.
        password: Server password used when resolving the requesting user.

    Returns:
        The configured APIRouter.
    """
    # Dependency resolving the current user, with server and password pre-bound.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/blocks", tags=["block"], response_model=List[Block])
    async def list_blocks(
        user_id: str = Depends(get_current_user_with_server),
        # query parameters
        label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
        templates_only: bool = Query(True, description="Whether to include only templates"),
        name: Optional[str] = Query(None, description="Name of the block"),
    ):
        # Clear the interface
        interface.clear()
        blocks = server.get_blocks(user_id=user_id, label=label, template=templates_only, name=name)
        # normalize a "no matches" None result to an empty list for the response model
        if blocks is None:
            return []
        return blocks

    @router.post("/blocks", tags=["block"], response_model=Block)
    async def create_block(
        request: CreateBlock = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        # Create a new block owned by the requesting user.
        interface.clear()
        request.user_id = user_id  # TODO: remove?
        return server.create_block(user_id=user_id, request=request)

    @router.post("/blocks/{block_id}", tags=["block"], response_model=Block)
    async def update_block(
        block_id: str,
        request: UpdateBlock = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        # Update an existing block; the body's ID must agree with the path parameter.
        interface.clear()
        # TODO: should this be in the param or the POST data?
        assert block_id == request.id
        return server.update_block(request)

    @router.delete("/blocks/{block_id}", tags=["block"], response_model=Block)
    async def delete_block(
        block_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        # Delete a block by ID.
        interface.clear()
        return server.delete_block(block_id=block_id)

    @router.get("/blocks/{block_id}", tags=["block"], response_model=Block)
    async def get_block(
        block_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        # Fetch a single block by ID, 404 if it does not exist.
        interface.clear()
        block = server.get_block(block_id=block_id)
        if block is None:
            raise HTTPException(status_code=404, detail="Block not found")
        return block

    return router

View File

@@ -1,40 +0,0 @@
from functools import partial
from typing import List
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
# Response schema pairing the active server configuration with its defaults.
class ConfigResponse(BaseModel):
    config: dict = Field(..., description="The server configuration object.")
    defaults: dict = Field(..., description="The defaults for the configuration.")
def setup_config_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the /config endpoints (LLM + embedding configs) on the module-level router.

    Args:
        server: Backend server whose active configurations are exposed.
        interface: Per-request queuing interface; cleared before each handler runs.
        password: Server password used when resolving the requesting user.

    Returns:
        The configured APIRouter.
    """
    # Pre-bind server and password into the current-user dependency.
    current_user_dep = partial(partial(get_current_user, server), password)

    @router.get("/config/llm", tags=["config"], response_model=List[LLMConfig])
    def get_llm_configs(user_id: str = Depends(current_user_dep)):
        """
        Retrieve the base configuration for the server.
        """
        interface.clear()
        # the server currently exposes exactly one LLM config
        return [server.server_llm_config]

    @router.get("/config/embedding", tags=["config"], response_model=List[EmbeddingConfig])
    def get_embedding_configs(user_id: str = Depends(current_user_dep)):
        """
        Retrieve the base configuration for the server.
        """
        interface.clear()
        # likewise, exactly one embedding config
        return [server.server_embedding_config]

    return router

View File

@@ -1,41 +0,0 @@
from functools import partial
from typing import List
from fastapi import APIRouter, Depends
from memgpt.schemas.job import Job
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
def setup_jobs_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the /jobs read endpoints on the module-level router.

    Args:
        server: Backend server that owns the job records.
        interface: Per-request queuing interface; cleared before each handler runs.
        password: Server password used when resolving the requesting user.

    Returns:
        The configured APIRouter.
    """
    # Pre-bind server and password into the current-user dependency.
    current_user_dep = partial(partial(get_current_user, server), password)

    @router.get("/jobs", tags=["jobs"], response_model=List[Job])
    async def list_jobs(
        user_id: str = Depends(current_user_dep),
    ):
        # List every job belonging to the requesting user.
        interface.clear()
        # TODO: add filtering by status
        return server.list_jobs(user_id=user_id)

    @router.get("/jobs/active", tags=["jobs"], response_model=List[Job])
    async def list_active_jobs(
        user_id: str = Depends(current_user_dep),
    ):
        # List only the user's currently-active jobs.
        interface.clear()
        return server.list_active_jobs(user_id=user_id)

    @router.get("/jobs/{job_id}", tags=["jobs"], response_model=Job)
    async def get_job(
        job_id: str,
        user_id: str = Depends(current_user_dep),
    ):
        # Fetch one job by ID.
        interface.clear()
        return server.get_job(job_id=job_id)

    return router

View File

@@ -1,38 +0,0 @@
from functools import partial
from typing import List
from fastapi import APIRouter
from pydantic import BaseModel, Field
from memgpt.schemas.llm_config import LLMConfig
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
# Response schema for the /models listing endpoint.
class ListModelsResponse(BaseModel):
    models: List[LLMConfig] = Field(..., description="List of model configurations.")
def setup_models_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the /models listing endpoint on the module-level router.

    Args:
        server: Backend server whose LLM configuration is exposed.
        interface: Per-request queuing interface; cleared before each handler runs.
        password: Accepted for signature consistency with the other setup_* helpers,
            but unused: the /models endpoint takes no user dependency.

    Returns:
        The configured APIRouter.
    """
    # NOTE(review): the original body built the current-user dependency via
    # `partial(partial(get_current_user, server), password)` but discarded the
    # result, so the endpoint was never authenticated. The dead expression is
    # removed here; whether /models *should* require auth needs confirming.

    @router.get("/models", tags=["models"], response_model=ListModelsResponse)
    async def list_models():
        # Clear the interface
        interface.clear()

        # currently, the server only supports one model, however this may change in the future
        llm_config = LLMConfig(
            model=server.server_llm_config.model,
            model_endpoint=server.server_llm_config.model_endpoint,
            model_endpoint_type=server.server_llm_config.model_endpoint_type,
            model_wrapper=server.server_llm_config.model_wrapper,
            context_window=server.server_llm_config.context_window,
        )

        return ListModelsResponse(models=[llm_config])

    return router

View File

@@ -1,488 +0,0 @@
import uuid
from typing import List, Optional
from fastapi import APIRouter, Body, HTTPException, Path, Query
from pydantic import BaseModel, Field
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.message import Message
from memgpt.schemas.openai.openai import (
AssistantFile,
MessageFile,
MessageRoleType,
OpenAIAssistant,
OpenAIMessage,
OpenAIRun,
OpenAIRunStep,
OpenAIThread,
Text,
ToolCall,
ToolCallOutput,
)
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.utils import get_utc_time
router = APIRouter()
# Request body for creating an assistant (an OpenAI-compatible wrapper around a MemGPT preset).
class CreateAssistantRequest(BaseModel):
    model: str = Field(..., description="The model to use for the assistant.")
    name: str = Field(..., description="The name of the assistant.")
    description: str = Field(None, description="The description of the assistant.")
    instructions: str = Field(..., description="The instructions for the assistant.")
    tools: List[str] = Field(None, description="The tools used by the assistant.")
    file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.")
    metadata: dict = Field(None, description="Metadata associated with the assistant.")
    # memgpt-only (not openai)
    embedding_model: str = Field(None, description="The model to use for the assistant.")
    ## TODO: remove
    # user_id: str = Field(..., description="The unique identifier of the user.")

# Request body for creating a thread (backed by a MemGPT agent).
class CreateThreadRequest(BaseModel):
    messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.")
    # memgpt-only
    assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. MemGPT preset)")

# Request body for modifying a thread's metadata.
class ModifyThreadRequest(BaseModel):
    metadata: dict = Field(None, description="Metadata associated with the thread.")

# Request body for modifying a message's metadata.
class ModifyMessageRequest(BaseModel):
    metadata: dict = Field(None, description="Metadata associated with the message.")

# Request body for modifying a run's metadata.
class ModifyRunRequest(BaseModel):
    metadata: dict = Field(None, description="Metadata associated with the run.")

# Request body for appending a message to a thread.
class CreateMessageRequest(BaseModel):
    role: str = Field(..., description="Role of the message sender (either 'user' or 'system')")
    content: str = Field(..., description="The message content to be processed by the agent.")
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the message.")

# Legacy request body for sending a user message directly to an agent.
class UserMessageRequest(BaseModel):
    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    message: str = Field(..., description="The message content to be processed by the agent.")
    stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")

# Legacy response carrying the agent's generated messages.
class UserMessageResponse(BaseModel):
    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")

# Legacy request body for paging through an agent's message history.
class GetAgentMessagesRequest(BaseModel):
    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")

# Response listing OpenAI-style message objects.
class ListMessagesResponse(BaseModel):
    messages: List[OpenAIMessage] = Field(..., description="List of message objects.")

# Request body for attaching a file to an assistant.
class CreateAssistantFileRequest(BaseModel):
    file_id: str = Field(..., description="The unique identifier of the file.")

# Request body for starting a run on an existing thread.
class CreateRunRequest(BaseModel):
    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    model: Optional[str] = Field(None, description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")

# Request body for creating a thread and immediately running it.
class CreateThreadRunRequest(BaseModel):
    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    thread: OpenAIThread = Field(..., description="The thread to run.")
    model: str = Field(..., description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")

# Deletion acknowledgements, mirroring OpenAI's `*.deleted` object shapes.
class DeleteAssistantResponse(BaseModel):
    id: str = Field(..., description="The unique identifier of the agent.")
    object: str = "assistant.deleted"
    deleted: bool = Field(..., description="Whether the agent was deleted.")

class DeleteAssistantFileResponse(BaseModel):
    id: str = Field(..., description="The unique identifier of the file.")
    object: str = "assistant.file.deleted"
    deleted: bool = Field(..., description="Whether the file was deleted.")

class DeleteThreadResponse(BaseModel):
    id: str = Field(..., description="The unique identifier of the agent.")
    object: str = "thread.deleted"
    deleted: bool = Field(..., description="Whether the agent was deleted.")

# Request body for submitting tool-call outputs back to a run.
class SubmitToolOutputsToRunRequest(BaseModel):
    tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.")
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface):
# create assistant (MemGPT agent)
@router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
# TODO: create preset
return OpenAIAssistant(
id=DEFAULT_PRESET,
name="default_preset",
description=request.description,
created_at=int(get_utc_time().timestamp()),
model=request.model,
instructions=request.instructions,
tools=request.tools,
file_ids=request.file_ids,
metadata=request.metadata,
)
@router.post("/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
def create_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantFileRequest = Body(...),
):
# TODO: add file to assistant
return AssistantFile(
id=request.file_id,
created_at=int(get_utc_time().timestamp()),
assistant_id=assistant_id,
)
@router.get("/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
def list_assistants(
limit: int = Query(1000, description="How many assistants to retrieve."),
order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
# TODO: implement list assistants (i.e. list available MemGPT presets)
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
def list_assistant_files(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
limit: int = Query(1000, description="How many files to retrieve."),
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
# TODO: list attached data sources to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def retrieve_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
# TODO: get and return preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
def retrieve_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: return data source attached to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def modify_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantRequest = Body(...),
):
# TODO: modify preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
def delete_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
# TODO: delete preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
def delete_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: delete source on preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads", tags=["threads"], response_model=OpenAIThread)
def create_thread(request: CreateThreadRequest = Body(...)):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
print("Create thread/agent", request)
# create a memgpt agent
agent_state = server.create_agent(
user_id=user_id,
)
# TODO: insert messages into recall memory
return OpenAIThread(
id=str(agent_state.id),
created_at=int(agent_state.created_at.timestamp()),
)
@router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
agent = server.get_agent(uuid.UUID(thread_id))
return OpenAIThread(
id=str(agent.id),
created_at=int(agent.created_at.timestamp()),
)
@router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
def modify_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: ModifyThreadRequest = Body(...),
):
# TODO: add agent metadata so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/threads/{thread_id}", tags=["threads"], response_model=DeleteThreadResponse)
def delete_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
# TODO: delete agent
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
):
agent_id = uuid.UUID(thread_id)
# create message object
message = Message(
user_id=user_id,
agent_id=agent_id,
role=request.role,
text=request.content,
)
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
# add message to agent
agent._append_to_messages([message])
openai_message = OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=message.text)],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# file_ids=message.file_ids,
# metadata=message.metadata,
)
return openai_message
@router.get("/threads/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
def list_messages(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many messages to retrieve."),
order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
after_uuid = uuid.UUID(after) if before else None
before_uuid = uuid.UUID(before) if before else None
agent_id = uuid.UUID(thread_id)
reverse = True if (order == "desc") else False
cursor, json_messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=limit,
after=after_uuid,
before=before_uuid,
order_by="created_at",
reverse=reverse,
)
print(json_messages[0]["text"])
# convert to openai style messages
openai_messages = [
OpenAIMessage(
id=str(message["id"]),
created_at=int(message["created_at"].timestamp()),
content=[Text(text=message["text"])],
role=message["role"],
thread_id=str(message["agent_id"]),
assistant_id=DEFAULT_PRESET, # TODO: update this
# file_ids=message.file_ids,
# metadata=message.metadata,
)
for message in json_messages
]
print("MESSAGES", openai_messages)
# TODO: cast back to message objects
return ListMessagesResponse(messages=openai_messages)
router.get("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def retrieve_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
):
message_id = uuid.UUID(message_id)
agent_id = uuid.UUID(thread_id)
message = server.get_agent_message(agent_id, message_id)
return OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=message.text)],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# file_ids=message.file_ids,
# metadata=message.metadata,
)
@router.get("/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
def retrieve_message_file(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: implement?
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def modify_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
request: ModifyMessageRequest = Body(...),
):
# TODO: add metada field to message so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
def create_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateRunRequest = Body(...),
):
# TODO: add request.instructions as a message?
agent_id = uuid.UUID(thread_id)
# TODO: override preset of agent with request.assistant_id
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
agent.step(user_message=None) # already has messages added
run_id = str(uuid.uuid4())
create_time = int(get_utc_time().timestamp())
return OpenAIRun(
id=run_id,
created_at=create_time,
thread_id=str(agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
status="completed", # TODO: eventaully allow offline execution
expires_at=create_time,
model=agent.agent_state.llm_config.model,
instructions=request.instructions,
)
@router.post("/threads/runs", tags=["runs"], response_model=OpenAIRun)
def create_thread_and_run(
request: CreateThreadRunRequest = Body(...),
):
# TODO: add a bunch of messages and execute
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/threads/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
def list_runs(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many runs to retrieve."),
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/threads/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
def list_run_steps(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
limit: int = Query(1000, description="How many run steps to retrieve."),
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def retrieve_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
def retrieve_run_step(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
step_id: str = Path(..., description="The unique identifier of the run step."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def modify_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: ModifyRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: SubmitToolOutputsToRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
def cancel_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
return router

View File

@@ -1,129 +0,0 @@
import json
import uuid
from functools import partial
from fastapi import APIRouter, Body, Depends, HTTPException
# from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import (
ChatCompletionResponse,
Choice,
Message,
UsageStatistics,
)
from memgpt.server.rest_api.agents.message import send_message_to_agent
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.utils import get_utc_time
router = APIRouter()
def setup_openai_chat_completions_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register an OpenAI-compatible /chat/completions proxy route on `router`.

    Args:
        server: MemGPT server used to route the message to an agent.
        interface: Queuing interface (kept for signature parity with the other routers).
        password: Server password used when resolving the bearer token to a user.

    Returns:
        The configured APIRouter.
    """
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.post("/chat/completions", tags=["chat_completions"], response_model=ChatCompletionResponse)
    async def create_chat_completion(
        request: ChatCompletionRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Send a message to a MemGPT agent via a /chat/completions request

        The bearer token will be used to identify the user.
        The 'user' field in the request should be set to the agent ID.
        """
        agent_id = request.user
        if agent_id is None:
            raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
        try:
            agent_id = uuid.UUID(agent_id)
        # BUG FIX: narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit) to the error uuid.UUID actually raises.
        except ValueError:
            raise HTTPException(status_code=400, detail="agent_id (in the 'user' field) must be a valid UUID")

        # exactly one user message is supported per request
        messages = request.messages
        if messages is None:
            raise HTTPException(status_code=400, detail="'messages' field must not be empty")
        if len(messages) > 1:
            raise HTTPException(status_code=400, detail="'messages' field must be a list of length 1")
        if messages[0].role != "user":
            raise HTTPException(status_code=400, detail="'messages[0].role' must be a 'user'")
        input_message = request.messages[0]

        if request.stream:
            print("Starting streaming OpenAI proxy response")
            return await send_message_to_agent(
                server=server,
                agent_id=agent_id,
                user_id=user_id,
                role=input_message.role,
                message=str(input_message.content),
                stream_legacy=False,
                # Turn streaming ON
                stream_steps=True,
                stream_tokens=True,
                # Turn on ChatCompletion mode (eg remaps send_message to content)
                chat_completion_mode=True,
            )
        else:
            print("Starting non-streaming OpenAI proxy response")
            response_messages = await send_message_to_agent(
                server=server,
                agent_id=agent_id,
                user_id=user_id,
                role=input_message.role,
                message=str(input_message.content),
                stream_legacy=False,
                # Turn streaming OFF
                stream_steps=False,
                stream_tokens=False,
            )
            # Concatenate all send_message outputs together
            completion_id = ""  # renamed from `id` to avoid shadowing the builtin
            visible_message_str = ""
            created_at = None
            for memgpt_msg in response_messages.messages:
                if "function_call" in memgpt_msg:
                    memgpt_function_call = memgpt_msg["function_call"]
                    if "name" in memgpt_function_call and memgpt_function_call["name"] == "send_message":
                        try:
                            memgpt_function_call_args = json.loads(memgpt_function_call["arguments"])
                            visible_message_str += memgpt_function_call_args["message"]
                            completion_id = memgpt_function_call["id"]
                            created_at = memgpt_msg["date"]
                        # BUG FIX: narrowed from a bare `except:`; only malformed JSON
                        # (json.JSONDecodeError is a ValueError subclass) or a missing
                        # key should count as an unparseable message.
                        except (KeyError, ValueError):
                            print(f"Failed to parse MemGPT message: {str(memgpt_function_call)}")
                    else:
                        print(f"Skipping function_call: {str(memgpt_function_call)}")
                else:
                    print(f"Skipping message: {str(memgpt_msg)}")

            response = ChatCompletionResponse(
                id=completion_id,
                created=created_at if created_at else get_utc_time(),
                choices=[
                    Choice(
                        finish_reason="stop",
                        index=0,
                        message=Message(
                            role="assistant",
                            content=visible_message_str,
                        ),
                    )
                ],
                # TODO add real usage
                usage=UsageStatistics(
                    completion_tokens=0,
                    prompt_tokens=0,
                    total_tokens=0,
                ),
            )
            return response

    return router

View File

@@ -1,70 +0,0 @@
import uuid
from functools import partial
from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.schemas.block import Persona as PersonaModel # TODO: modify
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class ListPersonasResponse(BaseModel):
    """Response envelope for GET /personas."""

    personas: List[PersonaModel] = Field(..., description="List of persona configurations.")
class CreatePersonaRequest(BaseModel):
    """Request body for POST /personas."""

    text: str = Field(..., description="The persona text.")
    name: str = Field(..., description="The name of the persona.")
def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the CRUD routes for personas on `router` and return it.

    Args:
        server: MemGPT server providing the metadata store (`server.ms`).
        interface: Queuing interface; cleared before each operation.
        password: Server password used when resolving the bearer token to a user.
    """
    get_current_user_with_server = partial(partial(get_current_user, server), password)
    @router.get("/personas", tags=["personas"], response_model=ListPersonasResponse)
    async def list_personas(
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """List all personas owned by the authenticated user."""
        # Clear the interface
        interface.clear()
        personas = server.ms.list_personas(user_id=user_id)
        return ListPersonasResponse(personas=personas)
    @router.post("/personas", tags=["personas"], response_model=PersonaModel)
    async def create_persona(
        request: CreatePersonaRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Create a persona from the request body and persist it."""
        # TODO: disallow duplicate names for personas
        interface.clear()
        new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
        # capture the generated id before persisting so the response echoes it back
        persona_id = new_persona.id
        server.ms.create_persona(new_persona)
        return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id)
    @router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
    async def delete_persona(
        persona_name: str,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Delete the named persona and return the deleted record."""
        interface.clear()
        persona = server.ms.delete_persona(persona_name, user_id=user_id)
        return persona
    @router.get("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
    async def get_persona(
        persona_name: str,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Fetch the named persona; 404 if it does not exist for this user."""
        interface.clear()
        persona = server.ms.get_persona(persona_name, user_id=user_id)
        if persona is None:
            raise HTTPException(status_code=404, detail="Persona not found")
        return persona
    return router

View File

@@ -0,0 +1,115 @@
from typing import List
from fastapi import APIRouter, Body, HTTPException, Path, Query
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.openai.openai import AssistantFile, OpenAIAssistant
from memgpt.server.rest_api.routers.openai.assistants.schemas import (
CreateAssistantFileRequest,
CreateAssistantRequest,
DeleteAssistantFileResponse,
DeleteAssistantResponse,
)
from memgpt.utils import get_utc_time
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
# BUG FIX: `router` was assigned twice — a bare APIRouter() that was immediately
# overwritten by the prefixed one. Keep only the router actually used.
router = APIRouter(prefix="/v1/assistants", tags=["assistants"])
# create assistant (MemGPT agent)
@router.post("/", response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
# TODO: create preset
return OpenAIAssistant(
id=DEFAULT_PRESET,
name="default_preset",
description=request.description,
created_at=int(get_utc_time().timestamp()),
model=request.model,
instructions=request.instructions,
tools=request.tools,
file_ids=request.file_ids,
metadata=request.metadata,
)
@router.post("/{assistant_id}/files", response_model=AssistantFile)
def create_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantFileRequest = Body(...),
):
# TODO: add file to assistant
return AssistantFile(
id=request.file_id,
created_at=int(get_utc_time().timestamp()),
assistant_id=assistant_id,
)
@router.get("/", response_model=List[OpenAIAssistant])
def list_assistants(
limit: int = Query(1000, description="How many assistants to retrieve."),
order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: implement list assistants (i.e. list available MemGPT presets)
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{assistant_id}/files", response_model=List[AssistantFile])
def list_assistant_files(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
limit: int = Query(1000, description="How many files to retrieve."),
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: list attached data sources to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{assistant_id}", response_model=OpenAIAssistant)
def retrieve_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
# TODO: get and return preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{assistant_id}/files/{file_id}", response_model=AssistantFile)
def retrieve_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: return data source attached to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{assistant_id}", response_model=OpenAIAssistant)
def modify_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantRequest = Body(...),
):
# TODO: modify preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/{assistant_id}", response_model=DeleteAssistantResponse)
def delete_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
# TODO: delete preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/{assistant_id}/files/{file_id}", response_model=DeleteAssistantFileResponse)
def delete_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: delete source on preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")

View File

@@ -0,0 +1,121 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from memgpt.schemas.openai.openai import (
MessageRoleType,
OpenAIMessage,
OpenAIThread,
ToolCall,
ToolCallOutput,
)
class CreateAssistantRequest(BaseModel):
    """Request body for creating an assistant (OpenAI-compatible, plus MemGPT extras)."""

    model: str = Field(..., description="The model to use for the assistant.")
    name: str = Field(..., description="The name of the assistant.")
    description: str = Field(None, description="The description of the assistant.")
    instructions: str = Field(..., description="The instructions for the assistant.")
    tools: List[str] = Field(None, description="The tools used by the assistant.")
    file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.")
    metadata: dict = Field(None, description="Metadata associated with the assistant.")
    # memgpt-only (not openai)
    # BUG FIX: the description was copy-pasted from `model`; this field selects
    # the embedding model, not the LLM.
    embedding_model: str = Field(None, description="The embedding model to use for the assistant.")
    ## TODO: remove
    # user_id: str = Field(..., description="The unique identifier of the user.")
class CreateThreadRequest(BaseModel):
    """Request body for creating a thread."""

    messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.")
    # memgpt-only
    assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. MemGPT preset)")
class ModifyThreadRequest(BaseModel):
    """Request body for modifying a thread's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the thread.")
class ModifyMessageRequest(BaseModel):
    """Request body for modifying a message's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the message.")
class ModifyRunRequest(BaseModel):
    """Request body for modifying a run's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the run.")
class CreateMessageRequest(BaseModel):
    """Request body for appending a message to a thread."""

    role: str = Field(..., description="Role of the message sender (either 'user' or 'system')")
    content: str = Field(..., description="The message content to be processed by the agent.")
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the message.")
class UserMessageRequest(BaseModel):
    """Request body for sending a user/system message directly to an agent."""

    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    message: str = Field(..., description="The message content to be processed by the agent.")
    stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
class UserMessageResponse(BaseModel):
    """Response envelope for a direct user message."""

    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
class GetAgentMessagesRequest(BaseModel):
    """Request body for paging through an agent's message history."""

    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")
class ListMessagesResponse(BaseModel):
    """Response envelope for listing a thread's messages."""

    messages: List[OpenAIMessage] = Field(..., description="List of message objects.")
class CreateAssistantFileRequest(BaseModel):
    """Request body for attaching an existing file to an assistant."""

    file_id: str = Field(..., description="The unique identifier of the file.")
class CreateRunRequest(BaseModel):
    """Request body for creating a run on an existing thread."""

    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    model: Optional[str] = Field(None, description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
class CreateThreadRunRequest(BaseModel):
    """Request body for creating a thread and running it in one call."""

    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    thread: OpenAIThread = Field(..., description="The thread to run.")
    model: str = Field(..., description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
class DeleteAssistantResponse(BaseModel):
    """Deletion receipt for an assistant (OpenAI `assistant.deleted` object)."""

    id: str = Field(..., description="The unique identifier of the agent.")
    object: str = "assistant.deleted"
    deleted: bool = Field(..., description="Whether the agent was deleted.")
class DeleteAssistantFileResponse(BaseModel):
    """Deletion receipt for an assistant file (OpenAI `assistant.file.deleted` object)."""

    id: str = Field(..., description="The unique identifier of the file.")
    object: str = "assistant.file.deleted"
    deleted: bool = Field(..., description="Whether the file was deleted.")
class DeleteThreadResponse(BaseModel):
    """Deletion receipt for a thread (OpenAI `thread.deleted` object)."""

    id: str = Field(..., description="The unique identifier of the agent.")
    object: str = "thread.deleted"
    deleted: bool = Field(..., description="Whether the agent was deleted.")
class SubmitToolOutputsToRunRequest(BaseModel):
    """Request body for submitting tool outputs to a run awaiting them."""

    tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.")

View File

@@ -0,0 +1,336 @@
import uuid
from typing import TYPE_CHECKING, List
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.agent import CreateAgent
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.message import Message
from memgpt.schemas.openai.openai import (
MessageFile,
OpenAIMessage,
OpenAIRun,
OpenAIRunStep,
OpenAIThread,
Text,
)
from memgpt.server.rest_api.routers.openai.assistants.schemas import (
CreateMessageRequest,
CreateRunRequest,
CreateThreadRequest,
CreateThreadRunRequest,
DeleteThreadResponse,
ListMessagesResponse,
ModifyMessageRequest,
ModifyRunRequest,
ModifyThreadRequest,
OpenAIThread,
SubmitToolOutputsToRunRequest,
)
from memgpt.server.rest_api.utils import get_memgpt_server
from memgpt.server.server import SyncServer
# BUG FIX: `get_utc_time` is a runtime helper function, not a type-annotation-only
# name; importing it under `if TYPE_CHECKING:` leaves it undefined at runtime for
# any handler that calls it. Import it unconditionally.
from memgpt.utils import get_utc_time
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
router = APIRouter(prefix="/v1/threads", tags=["threads"])
@router.post("/", response_model=OpenAIThread)
def create_thread(
request: CreateThreadRequest = Body(...),
server: SyncServer = Depends(get_memgpt_server),
):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
actor = server.get_current_user()
print("Create thread/agent", request)
# create a memgpt agent
agent_state = server.create_agent(
request=CreateAgent(),
user_id=actor.id,
)
# TODO: insert messages into recall memory
return OpenAIThread(
id=str(agent_state.id),
created_at=int(agent_state.created_at.timestamp()),
metadata={}, # TODO add metadata?
)
@router.get("/{thread_id}", response_model=OpenAIThread)
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
server: SyncServer = Depends(get_memgpt_server),
):
actor = server.get_current_user()
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
assert agent is not None
return OpenAIThread(
id=str(agent.id),
created_at=int(agent.created_at.timestamp()),
metadata={}, # TODO add metadata?
)
@router.get("/{thread_id}", response_model=OpenAIThread)
def modify_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: ModifyThreadRequest = Body(...),
):
# TODO: add agent metadata so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/{thread_id}", response_model=DeleteThreadResponse)
def delete_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
# TODO: delete agent
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
server: SyncServer = Depends(get_memgpt_server),
):
actor = server.get_current_user()
agent_id = thread_id
# create message object
message = Message(
user_id=actor.id,
agent_id=agent_id,
role=MessageRole(request.role),
text=request.content,
model=None,
tool_calls=None,
tool_call_id=None,
name=None,
)
agent = server._get_or_load_agent(agent_id=agent_id)
# add message to agent
agent._append_to_messages([message])
openai_message = OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=(message.text if message.text else ""))],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# TODO(sarah) fill in?
run_id=None,
file_ids=None,
metadata=None,
# file_ids=message.file_ids,
# metadata=message.metadata,
)
return openai_message
@router.get("/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
def list_messages(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many messages to retrieve."),
order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
server: SyncServer = Depends(get_memgpt_server),
):
actor = server.get_current_user()
after_uuid = after if before else None
before_uuid = before if before else None
agent_id = thread_id
reverse = True if (order == "desc") else False
json_messages = server.get_agent_recall_cursor(
user_id=actor.id,
agent_id=agent_id,
limit=limit,
after=after_uuid,
before=before_uuid,
order_by="created_at",
reverse=reverse,
return_message_object=True,
)
assert isinstance(json_messages, List)
assert all([isinstance(message, Message) for message in json_messages])
assert isinstance(json_messages[0], Message)
print(json_messages[0].text)
# convert to openai style messages
openai_messages = []
for message in json_messages:
assert isinstance(message, Message)
openai_messages.append(
OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=(message.text if message.text else ""))],
role=str(message.role),
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# TODO(sarah) fill in?
run_id=None,
file_ids=None,
metadata=None,
# file_ids=message.file_ids,
# metadata=message.metadata,
)
)
print("MESSAGES", openai_messages)
# TODO: cast back to message objects
return ListMessagesResponse(messages=openai_messages)
@router.get("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def retrieve_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Fetch one message from a thread, returned in OpenAI thread-message format."""
    # the thread ID is used directly as the backing agent's ID
    agent_id = thread_id
    message = server.get_agent_message(agent_id=agent_id, message_id=message_id)
    assert message is not None
    message_text = message.text if message.text else ""
    return OpenAIMessage(
        id=message_id,
        created_at=int(message.created_at.timestamp()),
        content=[Text(text=message_text)],
        role=message.role,
        thread_id=str(message.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # TODO(sarah) fill in?
        run_id=None,
        file_ids=None,
        metadata=None,
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
@router.get("/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
def retrieve_message_file(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    """Fetch a file attached to a message. Not supported yet; always 404s."""
    # TODO: implement?
    raise HTTPException(
        status_code=404,
        detail="Not yet implemented (coming soon)",
    )
@router.post("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def modify_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    request: ModifyMessageRequest = Body(...),
):
    """Modify a message's metadata. Not supported yet; always 404s."""
    # TODO: add metadata field to message so this can be modified
    raise HTTPException(
        status_code=404,
        detail="Not yet implemented (coming soon)",
    )
@router.post("/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
def create_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: CreateRunRequest = Body(...),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Execute a run on a thread by stepping the backing agent once, then report it as completed."""
    server.get_current_user()
    # TODO: add request.instructions as a message?
    # the thread ID is used directly as the backing agent's ID
    agent_id = thread_id
    # TODO: override preset of agent with request.assistant_id
    agent = server._get_or_load_agent(agent_id=agent_id)
    agent.step(user_message=None)  # already has messages added
    new_run_id = str(uuid.uuid4())
    now_ts = int(get_utc_time().timestamp())
    return OpenAIRun(
        id=new_run_id,
        created_at=now_ts,
        thread_id=str(agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        status="completed",  # TODO: eventually allow offline execution
        expires_at=now_ts,
        model=agent.agent_state.llm_config.model,
        instructions=request.instructions,
    )
@router.post("/runs", tags=["runs"], response_model=OpenAIRun)
def create_thread_and_run(
    request: CreateThreadRunRequest = Body(...),
):
    """Create a thread and immediately run it. Not supported yet; always 404s."""
    # TODO: add a bunch of messages and execute
    raise HTTPException(
        status_code=404,
        detail="Not yet implemented (coming soon)",
    )
@router.get("/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
def list_runs(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many runs to retrieve."),
    order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    # BUG FIX: the `before` description was copy-pasted from `after` and described
    # the wrong cursor in the generated OpenAPI docs
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List the runs belonging to a thread. Not supported yet; always 404s."""
    # TODO: store run information in a DB so it can be returned here
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
def list_run_steps(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    limit: int = Query(1000, description="How many run steps to retrieve."),
    order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    # BUG FIX: the `before` description was copy-pasted from `after` and described
    # the wrong cursor in the generated OpenAPI docs
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List the steps of a run. Not supported yet; always 404s."""
    # TODO: store run information in a DB so it can be returned here
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def retrieve_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
):
    """Fetch a single run. Not supported yet; always 404s."""
    raise HTTPException(
        status_code=404,
        detail="Not yet implemented (coming soon)",
    )
@router.get("/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
def retrieve_run_step(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    step_id: str = Path(..., description="The unique identifier of the run step."),
):
    """Fetch a single run step. Not supported yet; always 404s."""
    raise HTTPException(
        status_code=404,
        detail="Not yet implemented (coming soon)",
    )
@router.post("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def modify_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    request: ModifyRunRequest = Body(...),
):
    """Modify a run. Not supported yet; always 404s."""
    raise HTTPException(
        status_code=404,
        detail="Not yet implemented (coming soon)",
    )
@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    request: SubmitToolOutputsToRunRequest = Body(...),
):
    """Submit tool outputs for a run awaiting them. Not supported yet; always 404s."""
    raise HTTPException(
        status_code=404,
        detail="Not yet implemented (coming soon)",
    )
@router.post("/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
def cancel_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
):
    """Cancel an in-progress run. Not supported yet; always 404s."""
    raise HTTPException(
        status_code=404,
        detail="Not yet implemented (coming soon)",
    )

View File

@@ -0,0 +1,131 @@
import json
from typing import TYPE_CHECKING
from fastapi import APIRouter, Body, Depends, HTTPException
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.memgpt_message import FunctionCall, MemGPTMessage
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import (
ChatCompletionResponse,
Choice,
Message,
UsageStatistics,
)
# TODO this belongs in a controller!
from memgpt.server.rest_api.routers.v1.agents import send_message_to_agent
from memgpt.server.rest_api.utils import get_memgpt_server
if TYPE_CHECKING:
pass
from memgpt.server.server import SyncServer
from memgpt.utils import get_utc_time
# Router for the OpenAI-compatible /v1/chat/completions proxy endpoint.
router = APIRouter(prefix="/v1/chat/completions", tags=["chat_completions"])
@router.post("/", response_model=ChatCompletionResponse)
async def create_chat_completion(
    completion_request: ChatCompletionRequest = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Send a message to a MemGPT agent via a /chat/completions completion_request

    The bearer token will be used to identify the user.
    The 'user' field in the completion_request should be set to the agent ID.

    Raises:
        HTTPException(400): if 'user' (agent ID) is missing, 'messages' is empty or
            longer than one entry, or the single message is not a user message.
    """
    actor = server.get_current_user()
    agent_id = completion_request.user
    if agent_id is None:
        raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
    messages = completion_request.messages
    if messages is None:
        raise HTTPException(status_code=400, detail="'messages' field must not be empty")
    if len(messages) > 1:
        raise HTTPException(status_code=400, detail="'messages' field must be a list of length 1")
    if messages[0].role != "user":
        raise HTTPException(status_code=400, detail="'messages[0].role' must be a 'user'")
    input_message = completion_request.messages[0]
    if completion_request.stream:
        print("Starting streaming OpenAI proxy response")
        # TODO(charles) support multimodal parts
        assert isinstance(input_message.content, str)
        return await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=actor.id,
            role=MessageRole(input_message.role),
            message=input_message.content,
            # Turn streaming ON
            stream_steps=True,
            stream_tokens=True,
            # Turn on ChatCompletion mode (eg remaps send_message to content)
            chat_completion_mode=True,
            return_message_object=False,
        )
    else:
        print("Starting non-streaming OpenAI proxy response")
        # TODO(charles) support multimodal parts
        assert isinstance(input_message.content, str)
        response_messages = await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=actor.id,
            role=MessageRole(input_message.role),
            message=input_message.content,
            # Turn streaming OFF
            stream_steps=False,
            stream_tokens=False,
            return_message_object=False,
        )
        # Concatenate all send_message outputs together
        # (renamed from `id`, which shadowed the builtin)
        response_id = ""
        visible_message_str = ""
        created_at = None
        for memgpt_msg in response_messages.messages:
            assert isinstance(memgpt_msg, MemGPTMessage)
            if isinstance(memgpt_msg, FunctionCall):
                if memgpt_msg.name and memgpt_msg.name == "send_message":
                    try:
                        memgpt_function_call_args = json.loads(memgpt_msg.arguments)
                        visible_message_str += memgpt_function_call_args["message"]
                        response_id = memgpt_msg.id
                        created_at = memgpt_msg.date
                    # BUG FIX: was a bare `except:`, which also swallows
                    # KeyboardInterrupt/SystemExit; only parse failures should be skipped
                    except (json.JSONDecodeError, KeyError, TypeError):
                        print(f"Failed to parse MemGPT message: {str(memgpt_msg)}")
                else:
                    print(f"Skipping function_call: {str(memgpt_msg)}")
            else:
                print(f"Skipping message: {str(memgpt_msg)}")
        response = ChatCompletionResponse(
            id=response_id,
            created=created_at if created_at else get_utc_time(),
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=Message(
                        role="assistant",
                        content=visible_message_str,
                    ),
                )
            ],
            # TODO add real usage
            usage=UsageStatistics(
                completion_tokens=0,
                prompt_tokens=0,
                total_tokens=0,
            ),
        )
        return response

View File

@@ -0,0 +1,15 @@
from memgpt.server.rest_api.routers.v1.agents import router as agents_router
from memgpt.server.rest_api.routers.v1.blocks import router as blocks_router
from memgpt.server.rest_api.routers.v1.jobs import router as jobs_router
from memgpt.server.rest_api.routers.v1.llms import router as llm_router
from memgpt.server.rest_api.routers.v1.sources import router as sources_router
from memgpt.server.rest_api.routers.v1.tools import router as tools_router
# All v1 sub-routers, in the order they are mounted onto the FastAPI app.
ROUTERS = [
    tools_router,
    sources_router,
    agents_router,
    llm_router,
    blocks_router,
    jobs_router,
]

View File

@@ -0,0 +1,529 @@
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.responses import StreamingResponse
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.memory import (
ArchivalMemorySummary,
CreateArchivalMemory,
Memory,
RecallMemorySummary,
)
from memgpt.schemas.message import Message, UpdateMessage
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.utils import get_memgpt_server, sse_async_generator
from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
# Router for agent CRUD, memory, and messaging endpoints (mounted under /agents).
router = APIRouter(prefix="/agents", tags=["agents"])
@router.get("/", response_model=List[AgentState])
def list_agents(
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Return every agent (with its configuration) owned by the requesting user."""
    current_user = server.get_current_user()
    return server.list_agents(user_id=current_user.id)
@router.post("/", response_model=AgentState)
def create_agent(
    agent: CreateAgent = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Create and persist a new agent from the supplied configuration."""
    current_user = server.get_current_user()
    # stamp the request with the authenticated user's ID before creation
    agent.user_id = current_user.id
    return server.create_agent(agent, user_id=current_user.id)
@router.patch("/{agent_id}", response_model=AgentState)
def update_agent(
    agent_id: str,
    update_agent: UpdateAgentState = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Update an existing agent in place."""
    current_user = server.get_current_user()
    # the path parameter is authoritative for which agent gets updated
    update_agent.id = agent_id
    return server.update_agent(update_agent, user_id=current_user.id)
@router.get("/{agent_id}", response_model=AgentState)
def get_agent_state(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Fetch the current state of an agent, or 404 if it does not exist."""
    current_user = server.get_current_user()
    if not server.ms.get_agent(user_id=current_user.id, agent_id=agent_id):
        raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
    return server.get_agent_state(user_id=current_user.id, agent_id=agent_id)
@router.delete("/{agent_id}")
def delete_agent(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Permanently delete an agent belonging to the requesting user."""
    current_user = server.get_current_user()
    return server.delete_agent(user_id=current_user.id, agent_id=agent_id)
@router.get("/{agent_id}/sources", response_model=List[Source])
def get_agent_sources(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """List the data sources attached to an agent."""
    # resolve the current user for auth side effects; result not needed here
    server.get_current_user()
    return server.list_attached_sources(agent_id)
@router.get("/{agent_id}/memory/messages", response_model=List[Message])
def get_agent_in_context_messages(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Return the messages currently inside the agent's context window."""
    return server.get_in_context_messages(agent_id=agent_id)
@router.get("/{agent_id}/memory", response_model=Memory)
def get_agent_memory(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Return the current core-memory state of the given agent."""
    return server.get_agent_memory(agent_id=agent_id)
@router.patch("/{agent_id}/memory", response_model=Memory)
def update_agent_memory(
    agent_id: str,
    request: Dict = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Overwrite parts of an agent's core memory with the supplied contents
    (e.g. human/persona sections) and return the updated memory."""
    current_user = server.get_current_user()
    return server.update_agent_core_memory(user_id=current_user.id, agent_id=agent_id, new_memory_contents=request)
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary)
def get_agent_recall_memory_summary(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Return a summary of the agent's recall memory."""
    return server.get_recall_memory_summary(agent_id=agent_id)
@router.get("/{agent_id}/memory/archival", response_model=ArchivalMemorySummary)
def get_agent_archival_memory_summary(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Return a summary of the agent's archival memory."""
    return server.get_archival_memory_summary(agent_id=agent_id)
@router.get("/{agent_id}/archival", response_model=List[Passage])
def get_agent_archival_memory(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
    after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
    before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
    limit: Optional[int] = Query(None, description="How many results to include in the response."),
):
    """Page through the passages stored in an agent's archival memory."""
    current_user = server.get_current_user()
    # TODO need to add support for non-postgres here
    # chroma will throw:
    #     raise ValueError("Cannot run get_all_cursor with chroma")
    return server.get_agent_archival_cursor(
        user_id=current_user.id,
        agent_id=agent_id,
        after=after,
        before=before,
        limit=limit,
    )
@router.post("/{agent_id}/archival", response_model=List[Passage])
def insert_agent_archival_memory(
    agent_id: str,
    request: CreateArchivalMemory = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Store a new memory in the agent's archival memory and return the created passages."""
    current_user = server.get_current_user()
    return server.insert_archival_memory(user_id=current_user.id, agent_id=agent_id, memory_contents=request.text)
# TODO(ethan): query or path parameter for memory_id?
# @router.delete("/{agent_id}/archival")
@router.delete("/{agent_id}/archival/{memory_id}")
def delete_agent_archival_memory(
    agent_id: str,
    memory_id: str,
    # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Remove one memory from the agent's archival store and confirm the deletion."""
    current_user = server.get_current_user()
    server.delete_archival_memory(user_id=current_user.id, agent_id=agent_id, memory_id=memory_id)
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={"message": f"Memory id={memory_id} successfully deleted"},
    )
@router.get("/{agent_id}/messages", response_model=List[Message])
def get_agent_messages(
    agent_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
    before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
    limit: int = Query(10, description="Maximum number of messages to retrieve."),
    msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."),
):
    """Page through an agent's message history, newest first."""
    current_user = server.get_current_user()
    return server.get_agent_recall_cursor(
        user_id=current_user.id,
        agent_id=agent_id,
        before=before,
        limit=limit,
        reverse=True,
        return_message_object=msg_object,
    )
@router.patch("/{agent_id}/messages/{message_id}", response_model=Message)
def update_message(
    agent_id: str,
    message_id: str,
    request: UpdateMessage = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """
    Update the details of a message associated with an agent.

    Raises:
        HTTPException(400): if the ID in the request body does not match the path.
    """
    # BUG FIX: this was an `assert`, which is stripped under `python -O` and
    # surfaces as a 500 (AssertionError) instead of a client error when it fires
    if request.id != message_id:
        raise HTTPException(status_code=400, detail=f"Message ID mismatch: {request.id} != {message_id}")
    return server.update_agent_message(agent_id=agent_id, request=request)
@router.post("/{agent_id}/messages", response_model=MemGPTResponse)
async def send_message(
    agent_id: str,
    server: SyncServer = Depends(get_memgpt_server),
    request: MemGPTRequest = Body(...),
):
    """
    Process a user message and return the agent's response.

    The response may be streamed (SSE) when 'stream_steps' or 'stream_tokens'
    is set on the request; otherwise a buffered MemGPTResponse is returned.
    """
    current_user = server.get_current_user()
    # TODO(charles): support sending multiple messages
    assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
    first_message = request.messages[0]
    return await send_message_to_agent(
        server=server,
        agent_id=agent_id,
        user_id=current_user.id,
        role=first_message.role,
        message=first_message.text,
        stream_steps=request.stream_steps,
        stream_tokens=request.stream_tokens,
        return_message_object=request.return_message_object,
    )
# TODO: move this into server.py?
async def send_message_to_agent(
    server: SyncServer,
    agent_id: str,
    user_id: str,
    role: MessageRole,
    message: str,
    stream_steps: bool,
    stream_tokens: bool,
    return_message_object: bool,  # Should be True for Python Client, False for REST API
    chat_completion_mode: Optional[bool] = False,
    timestamp: Optional[datetime] = None,
    # related to whether or not we return `MemGPTMessage`s or `Message`s
) -> Union[StreamingResponse, MemGPTResponse]:
    """Send a single message to an agent and return either an SSE stream or a buffered response.

    Split off into a separate function so that it can be imported in the /chat/completion proxy.

    Args:
        server: the synchronous MemGPT server instance.
        agent_id: ID of the target agent.
        user_id: ID of the acting user.
        role: message role; only `user` and `system` are accepted (anything else -> 500).
        message: the message text to deliver.
        stream_steps: if True, return a StreamingResponse of SSE events; if False,
            buffer the whole stream and return a MemGPTResponse.
        stream_tokens: if True, stream at token granularity (requires stream_steps=True).
        return_message_object: if True (non-streaming only), convert the buffered
            stream back into `Message` objects via their attached IDs.
        chat_completion_mode: remaps output for the OpenAI chat-completions proxy.
        timestamp: optional timestamp forwarded to the server message call.

    Raises:
        HTTPException(400): if stream_tokens is set without stream_steps.
        HTTPException(500): for a bad role, or any unexpected error while processing.
    """
    # TODO: @charles is this the correct way to handle?
    include_final_message = True
    # determine role
    if role == MessageRole.user:
        message_func = server.user_message
    elif role == MessageRole.system:
        message_func = server.system_message
    else:
        raise HTTPException(status_code=500, detail=f"Bad role {role}")
    if not stream_steps and stream_tokens:
        raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
    # For streaming response
    try:
        # TODO: move this logic into server.py
        # Get the generator object off of the agent's streaming interface
        # This will be attached to the POST SSE request used under-the-hood
        memgpt_agent = server._get_or_load_agent(agent_id=agent_id)
        streaming_interface = memgpt_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
        # Enable token-streaming within the request if desired
        streaming_interface.streaming_mode = stream_tokens
        # "chatcompletion mode" does some remapping and ignores inner thoughts
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode
        # streaming_interface.allow_assistant_message = stream
        # streaming_interface.function_call_legacy_mode = stream
        # Offload the synchronous message_func to a separate thread
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
        )
        if stream_steps:
            if return_message_object:
                # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
                raise NotImplementedError
            # return a stream
            return StreamingResponse(
                sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
                media_type="text/event-stream",
            )
        else:
            # buffer the stream, then return the list
            generated_stream = []
            async for message in streaming_interface.get_generator():
                assert (
                    isinstance(message, MemGPTMessage)
                    or isinstance(message, LegacyMemGPTMessage)
                    or isinstance(message, MessageStreamStatus)
                ), type(message)
                generated_stream.append(message)
                if message == MessageStreamStatus.done:
                    break
            # Get rid of the stream status messages
            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
            # wait for the background agent step (started above) to finish and report usage
            usage = await task
            # By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
            # If we want to convert these to Message, we can use the attached IDs
            # NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
            # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
            if return_message_object:
                message_ids = [m.id for m in filtered_stream]
                message_ids = deduplicate(message_ids)
                message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
                return MemGPTResponse(messages=message_objs, usage=usage)
            else:
                return MemGPTResponse(messages=filtered_stream, usage=usage)
    except HTTPException:
        raise
    except Exception as e:
        # catch-all boundary: log the traceback and surface the error as a 500
        print(e)
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
##### MISSING #######
# @router.post("/{agent_id}/command")
# def run_command(
# agent_id: "UUID",
# command: "AgentCommandRequest",
#
# server: "SyncServer" = Depends(get_memgpt_server),
# ):
# """
# Execute a command on a specified agent.
# This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately.
# Raises an HTTPException for any processing errors.
# """
# actor = server.get_current_user()
#
# response = server.run_command(user_id=actor.id,
# agent_id=agent_id,
# command=command.command)
# return AgentCommandResponse(response=response)
# @router.get("/{agent_id}/config")
# def get_agent_config(
# agent_id: "UUID",
#
# server: "SyncServer" = Depends(get_memgpt_server),
# ):
# """
# Retrieve the configuration for a specific agent.
# This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
# """
# actor = server.get_current_user()
#
# if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
## agent does not exist
# raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
# agent_state = server.get_agent_config(user_id=actor.id, agent_id=agent_id)
## get sources
# attached_sources = server.list_attached_sources(agent_id=agent_id)
## configs
# llm_config = LLMConfig(**vars(agent_state.llm_config))
# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config))
# return GetAgentResponse(
# agent_state=AgentState(
# id=agent_state.id,
# name=agent_state.name,
# user_id=agent_state.user_id,
# llm_config=llm_config,
# embedding_config=embedding_config,
# state=agent_state.state,
# created_at=int(agent_state.created_at.timestamp()),
# tools=agent_state.tools,
# system=agent_state.system,
# metadata=agent_state._metadata,
# ),
# last_run_at=None, # TODO
# sources=attached_sources,
# )
# @router.patch("/{agent_id}/rename", response_model=GetAgentResponse)
# def update_agent_name(
# agent_id: "UUID",
# agent_rename: AgentRenameRequest,
#
# server: "SyncServer" = Depends(get_memgpt_server),
# ):
# """
# Updates the name of a specific agent.
# This changes the name of the agent in the database but does NOT edit the agent's persona.
# """
# valid_name = agent_rename.agent_name
# actor = server.get_current_user()
#
# agent_state = server.rename_agent(user_id=actor.id, agent_id=agent_id, new_agent_name=valid_name)
## get sources
# attached_sources = server.list_attached_sources(agent_id=agent_id)
# llm_config = LLMConfig(**vars(agent_state.llm_config))
# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config))
# return GetAgentResponse(
# agent_state=AgentState(
# id=agent_state.id,
# name=agent_state.name,
# user_id=agent_state.user_id,
# llm_config=llm_config,
# embedding_config=embedding_config,
# state=agent_state.state,
# created_at=int(agent_state.created_at.timestamp()),
# tools=agent_state.tools,
# system=agent_state.system,
# ),
# last_run_at=None, # TODO
# sources=attached_sources,
# )
# @router.get("/{agent_id}/archival/all", response_model=GetAgentArchivalMemoryResponse)
# def get_agent_archival_memory_all(
# agent_id: "UUID",
#
# server: "SyncServer" = Depends(get_memgpt_server),
# ):
# """
# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
# """
# actor = server.get_current_user()
#
# archival_memories = server.get_all_archival_memories(user_id=actor.id, agent_id=agent_id)
# print("archival_memories:", archival_memories)
# archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories]
# return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)

View File

@@ -0,0 +1,73 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from memgpt.schemas.block import Block, CreateBlock, UpdateBlock
from memgpt.server.rest_api.utils import get_memgpt_server
from memgpt.server.server import SyncServer
if TYPE_CHECKING:
pass
# Router for memory-block CRUD endpoints (mounted under /blocks).
router = APIRouter(prefix="/blocks", tags=["blocks"])
@router.get("/", response_model=List[Block])
def list_blocks(
    # query parameters
    label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
    templates_only: bool = Query(True, description="Whether to include only templates"),
    name: Optional[str] = Query(None, description="Name of the block"),
    server: SyncServer = Depends(get_memgpt_server),
):
    """List the current user's memory blocks, optionally filtered by label, template flag, and name."""
    current_user = server.get_current_user()
    blocks = server.get_blocks(user_id=current_user.id, label=label, template=templates_only, name=name)
    # normalize a "no results" None into an empty list for the response model
    return blocks if blocks is not None else []
@router.post("/", response_model=Block)
def create_block(
    create_block: CreateBlock = Body(...),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Create a new memory block owned by the requesting user."""
    current_user = server.get_current_user()
    # stamp the request with the authenticated user's ID before creation
    create_block.user_id = current_user.id
    return server.create_block(user_id=current_user.id, request=create_block)
@router.patch("/{block_id}", response_model=Block)
def update_block(
    block_id: str,
    updated_block: UpdateBlock = Body(...),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Apply a partial update to an existing block; the path ID is authoritative."""
    updated_block.id = block_id
    return server.update_block(request=updated_block)
# TODO: delete should not return anything
@router.delete("/{block_id}", response_model=Block)
def delete_block(
    block_id: str,
    server: SyncServer = Depends(get_memgpt_server),
):
    """Delete a block by ID and return the deleted block."""
    return server.delete_block(block_id=block_id)
@router.get("/{block_id}", response_model=Block)
def get_block(
    block_id: str,
    server: SyncServer = Depends(get_memgpt_server),
):
    """Fetch a single block by ID, or 404 when no such block exists."""
    fetched = server.get_block(block_id=block_id)
    if fetched is None:
        raise HTTPException(status_code=404, detail="Block not found")
    return fetched

View File

@@ -0,0 +1,46 @@
from typing import List
from fastapi import APIRouter, Depends
from memgpt.schemas.job import Job
from memgpt.server.rest_api.utils import get_memgpt_server
from memgpt.server.server import SyncServer
# Router for background-job status endpoints (mounted under /jobs).
router = APIRouter(prefix="/jobs", tags=["jobs"])
@router.get("/", response_model=List[Job])
def list_jobs(
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """List every job belonging to the requesting user."""
    current_user = server.get_current_user()
    # TODO: add filtering by status
    return server.list_jobs(user_id=current_user.id)
@router.get("/active", response_model=List[Job])
def list_active_jobs(
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """List only the currently active jobs belonging to the requesting user."""
    current_user = server.get_current_user()
    return server.list_active_jobs(user_id=current_user.id)
@router.get("/{job_id}", response_model=Job)
def get_job(
    job_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Fetch the status of a single job by its ID."""
    return server.get_job(job_id=job_id)

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, List
from fastapi import APIRouter, Depends
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.server.rest_api.utils import get_memgpt_server
if TYPE_CHECKING:
from memgpt.server.server import SyncServer
# Router for model/backend discovery endpoints (mounted under /models).
router = APIRouter(prefix="/models", tags=["models", "llms"])
@router.get("/", response_model=List[LLMConfig])
def list_llm_backends(
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """List the LLM model configurations available on this server."""
    return server.list_models()
@router.get("/embedding", response_model=List[EmbeddingConfig])
def list_embedding_backends(
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """List the embedding model configurations available on this server."""
    return server.list_embedding_models()

View File

@@ -0,0 +1,199 @@
import os
import tempfile
from typing import List

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, UploadFile

from memgpt.schemas.document import Document
from memgpt.schemas.job import Job
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
from memgpt.server.rest_api.utils import get_memgpt_server
from memgpt.server.server import SyncServer
# These can be forward refs, but because FastAPI needs them at runtime they must be imported normally
# Router for data-source CRUD and attachment endpoints (mounted under /sources).
router = APIRouter(prefix="/sources", tags=["sources"])
@router.get("/{source_id}", response_model=Source)
def get_source(
    source_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """
    Get a single source by its ID.
    """
    actor = server.get_current_user()
    return server.get_source(source_id=source_id, user_id=actor.id)
@router.get("/name/{source_name}", response_model=str)
def get_source_id_by_name(
    source_name: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """
    Get the ID of a source by its name.
    """
    actor = server.get_current_user()
    # despite the local name, this holds the source's ID (see server.get_source_id)
    source = server.get_source_id(source_name=source_name, user_id=actor.id)
    return source
@router.get("/", response_model=List[Source])
def list_sources(
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Return every data source created by the current user."""
    current_user = server.get_current_user()
    sources = server.list_all_sources(user_id=current_user.id)
    return sources
@router.post("/", response_model=Source)
def create_source(
    source: SourceCreate,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Create a new data source owned by the current user."""
    current_user = server.get_current_user()
    created = server.create_source(request=source, user_id=current_user.id)
    return created
@router.patch("/{source_id}", response_model=Source)
def update_source(
    source_id: str,
    source: SourceUpdate,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """
    Update the name or documentation of an existing data source.

    Raises a 400 if the ID in the path disagrees with the ID in the body.
    """
    actor = server.get_current_user()
    # Validate with a proper client error instead of `assert`: asserts are
    # stripped under `python -O` and otherwise surface to clients as a 500.
    if source.id != source_id:
        raise HTTPException(status_code=400, detail="Source ID in path must match ID in request body")
    return server.update_source(request=source, user_id=actor.id)
@router.delete("/{source_id}")
def delete_source(
    source_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Delete the given data source on behalf of the current user."""
    current_user = server.get_current_user()
    server.delete_source(source_id=source_id, user_id=current_user.id)
@router.post("/{source_id}/attach", response_model=Source)
def attach_source_to_agent(
    source_id: str,
    agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """
    Attach a data source to an existing agent.

    Raises a 404 if the source does not exist for the current user.
    """
    actor = server.get_current_user()
    source = server.ms.get_source(source_id=source_id, user_id=actor.id)
    # A missing source is a client error (404), not an AssertionError/500;
    # `assert` is also stripped when Python runs with -O.
    if source is None:
        raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")
    source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id)
    return source
@router.post("/{source_id}/detach")
def detach_source_from_agent(
    source_id: str,
    agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
    server: "SyncServer" = Depends(get_memgpt_server),
) -> None:
    """Detach a data source from an existing agent."""
    current_user = server.get_current_user()
    server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=current_user.id)
@router.post("/{source_id}/upload", response_model=Job)
def upload_file_to_source(
    file: UploadFile,
    source_id: str,
    background_tasks: BackgroundTasks,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """
    Upload a file to a data source and start a background embedding job.

    Returns the persisted Job record tracking the load.
    """
    actor = server.get_current_user()
    source = server.ms.get_source(source_id=source_id, user_id=actor.id)
    # A missing source is a 404, not an AssertionError (stripped under -O).
    if source is None:
        raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")

    # Read the payload now: the upload's spooled file may no longer be
    # readable once the request completes and the background task runs.
    file_bytes = file.file.read()  # renamed from `bytes`, which shadowed the builtin

    # Create and persist a job record that tracks the embedding work.
    job = Job(
        user_id=actor.id,
        metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id},
        completed_at=None,
    )
    job_id = job.id
    server.ms.create_job(job)

    # Schedule the actual file load to run after the response is sent.
    background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=file_bytes)

    # Re-read the job so the response reflects the persisted state.
    job = server.ms.get_job(job_id=job_id)
    if job is None:
        raise HTTPException(status_code=500, detail="Failed to create upload job.")
    return job
@router.get("/{source_id}/passages", response_model=List[Passage])
def list_passages(
    source_id: str,
    server: SyncServer = Depends(get_memgpt_server),
):
    """Return all passages stored under the given data source."""
    current_user = server.get_current_user()
    return server.list_data_source_passages(user_id=current_user.id, source_id=source_id)
@router.get("/{source_id}/documents", response_model=List[Document])
def list_documents(
    source_id: str,
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Return all documents stored under the given data source."""
    current_user = server.get_current_user()
    return server.list_data_source_documents(user_id=current_user.id, source_id=source_id)
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
    """Persist the uploaded bytes to a temp file and load it into the source."""
    # The temporary directory (and the file written inside it) is removed
    # automatically when the context manager exits.
    with tempfile.TemporaryDirectory() as tmpdirname:
        target_path = os.path.join(str(tmpdirname), str(file.filename))
        with open(target_path, "wb") as fh:
            fh.write(bytes)
        server.load_file_to_source(source_id, target_path, job_id)

View File

@@ -0,0 +1,103 @@
from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
from memgpt.server.rest_api.utils import get_memgpt_server
from memgpt.server.server import SyncServer
router = APIRouter(prefix="/tools", tags=["tools"])
@router.delete("/{tool_id}")
def delete_tool(
    tool_id: str,
    server: SyncServer = Depends(get_memgpt_server),
):
    """
    Delete a tool by ID.
    """
    # NOTE(review): no ownership check — any caller can delete any tool by id.
    # actor = server.get_current_user()
    server.delete_tool(tool_id=tool_id)
@router.get("/{tool_id}", response_model=Tool)
def get_tool(
    tool_id: str,
    server: SyncServer = Depends(get_memgpt_server),
):
    """Fetch a tool by its ID; responds 404 if no such tool exists."""
    found = server.get_tool(tool_id=tool_id)
    if found is None:
        raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
    return found
@router.get("/name/{tool_name}", response_model=str)
def get_tool_id(
    tool_name: str,
    server: SyncServer = Depends(get_memgpt_server),
):
    """Look up a tool's ID by name for the current user; 404 if absent."""
    current_user = server.get_current_user()
    found = server.get_tool_id(tool_name, user_id=current_user.id)
    if found is None:
        raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
    return found
@router.get("/", response_model=List[Tool])
def list_all_tools(
    server: SyncServer = Depends(get_memgpt_server),
):
    """
    Get a list of all tools available to agents created by a user.
    """
    actor = server.get_current_user()
    # Removed a stray no-op `actor.id` expression statement left over from
    # earlier experimentation.
    # TODO: add back when user-specific
    return server.list_tools(user_id=actor.id)
    # return server.ms.list_tools(user_id=None)
@router.post("/", response_model=Tool)
def create_tool(
    tool: ToolCreate = Body(...),
    update: bool = False,
    server: SyncServer = Depends(get_memgpt_server),
):
    """
    Create a new tool.

    When `update` is true, an existing tool with the same identity is
    updated instead of raising a conflict.
    """
    actor = server.get_current_user()
    # Honor the caller's `update` flag. It was previously hard-coded to
    # True (with the pass-through commented out), which silently turned
    # every create into an upsert regardless of what the client asked for.
    return server.create_tool(
        request=tool,
        update=update,
        user_id=actor.id,
    )
@router.patch("/{tool_id}", response_model=Tool)
def update_tool(
    tool_id: str,
    request: ToolUpdate = Body(...),
    server: SyncServer = Depends(get_memgpt_server),
):
    """
    Update an existing tool.

    Raises a 400 if the ID in the path disagrees with the ID in the body.
    """
    # Mismatched IDs are a client error; `assert` would be stripped under
    # `python -O` and otherwise surface as a 500.
    if tool_id != request.id:
        raise HTTPException(status_code=400, detail="Tool ID in path must match tool ID in request body")
    # NOTE(review): return value unused — presumably called for its side
    # effect of resolving/creating the current user; confirm.
    server.get_current_user()
    return server.update_tool(request)

View File

@@ -0,0 +1,109 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from memgpt.schemas.api_key import APIKey, APIKeyCreate
from memgpt.schemas.user import User, UserCreate
from memgpt.server.rest_api.utils import get_memgpt_server
# from memgpt.server.schemas.users import (
# CreateAPIKeyRequest,
# CreateAPIKeyResponse,
# CreateUserRequest,
# CreateUserResponse,
# DeleteAPIKeyResponse,
# DeleteUserResponse,
# GetAllUsersResponse,
# GetAPIKeysResponse,
# )
if TYPE_CHECKING:
from memgpt.schemas.user import User
from memgpt.server.server import SyncServer
router = APIRouter(prefix="/users", tags=["users", "admin"])
@router.get("/", tags=["admin"], response_model=List[User])
def get_all_users(
    cursor: Optional[str] = Query(None),
    limit: Optional[int] = Query(50),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Return a page of all users in the database."""
    try:
        _next_cursor, page = server.ms.get_all_users(cursor=cursor, limit=limit)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
    return page
@router.post("/", tags=["admin"], response_model=User)
def create_user(
    request: UserCreate = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Create a new user record and return it."""
    return server.create_user(request)
@router.delete("/", tags=["admin"], response_model=User)
def delete_user(
    user_id: str = Query(..., description="The user_id key to be deleted."),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """
    Delete a user from the database and return the deleted record.

    Responds 404 if the user does not exist.
    """
    # TODO make a soft deletion, instead of a hard deletion
    try:
        user = server.ms.get_user(user_id=user_id)
        if user is None:
            # f-prefix dropped: the literal has no placeholders (flake8 F541)
            raise HTTPException(status_code=404, detail="User does not exist")
        server.ms.delete_user(user_id=user_id)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
    return user
@router.post("/keys", response_model=APIKey)
def create_new_api_key(
    create_key: APIKeyCreate = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Mint a new API key for the user described in the request."""
    new_key = server.create_api_key(create_key)
    return new_key
@router.get("/keys", response_model=List[APIKey])
def get_api_keys(
    user_id: str = Query(..., description="The unique identifier of the user."),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """
    Get a list of all API keys for a user.

    Responds 404 if the user does not exist.
    """
    if server.ms.get_user(user_id=user_id) is None:
        # f-prefix dropped: the literal has no placeholders (flake8 F541)
        raise HTTPException(status_code=404, detail="User does not exist")
    api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id)
    return api_keys
@router.delete("/keys", response_model=APIKey)
def delete_api_key(
    api_key: str = Query(..., description="The API key to be deleted."),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Delete the given API key and return its record."""
    deleted = server.delete_api_key(api_key)
    return deleted

View File

@@ -12,27 +12,24 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.middleware.cors import CORSMiddleware
from memgpt.server.constants import REST_DEFAULT_PORT
from memgpt.server.rest_api.admin.agents import setup_agents_admin_router
from memgpt.server.rest_api.admin.tools import setup_tools_index_router
from memgpt.server.rest_api.admin.users import setup_admin_router
from memgpt.server.rest_api.agents.index import setup_agents_index_router
from memgpt.server.rest_api.agents.memory import setup_agents_memory_router
from memgpt.server.rest_api.agents.message import setup_agents_message_router
from memgpt.server.rest_api.auth.index import setup_auth_router
from memgpt.server.rest_api.block.index import setup_block_index_router
from memgpt.server.rest_api.config.index import setup_config_index_router
from memgpt.server.rest_api.auth.index import (
setup_auth_router, # TODO: probably remove right?
)
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.jobs.index import setup_jobs_index_router
from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import (
setup_openai_assistant_router,
from memgpt.server.rest_api.routers.openai.assistants.assistants import (
router as openai_assistants_router,
)
from memgpt.server.rest_api.openai_chat_completions.chat_completions import (
setup_openai_chat_completions_router,
from memgpt.server.rest_api.routers.openai.assistants.threads import (
router as openai_threads_router,
)
from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import (
router as openai_chat_completions_router,
)
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_user_tools_index_router
from memgpt.server.server import SyncServer
from memgpt.settings import settings
@@ -44,8 +41,6 @@ Start the server with:
poetry run uvicorn server:app --reload
"""
# interface: QueuingInterface = QueuingInterface()
# interface: StreamingServerInterface = StreamingServerInterface()
interface: StreamingServerInterface = StreamingServerInterface
server: SyncServer = SyncServer(default_interface_factory=lambda: interface())
@@ -66,10 +61,9 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security
raise HTTPException(status_code=401, detail="Unauthorized")
ADMIN_PREFIX = "/admin"
ADMIN_API_PREFIX = "/api/admin"
API_PREFIX = "/api"
OPENAI_API_PREFIX = "/v1"
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"
app = FastAPI()
@@ -81,36 +75,51 @@ app.add_middleware(
allow_headers=["*"],
)
# v1_routes are the MemGPT API routes
for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
# this gives undocumented routes for "latest" and bare api calls.
# we should always tie this to the newest version of the api.
app.include_router(route, prefix="", include_in_schema=False)
app.include_router(route, prefix="/latest", include_in_schema=False)
# admin/users
app.include_router(users_router, prefix=ADMIN_PREFIX)
# openai
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
# /api/auth endpoints
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)
# /admin/users endpoints
app.include_router(setup_admin_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)])
app.include_router(setup_tools_index_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)])
# # Serve static files
# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
# app.mount("/assets", StaticFiles(directory=os.path.join(static_files_path, "assets")), name="static")
# /api/admin/agents endpoints
app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)])
# /api/agents endpoints
app.include_router(setup_agents_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_jobs_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
# # Serve favicon
# @app.get("/favicon.ico")
# async def favicon():
# return FileResponse(os.path.join(static_files_path, "favicon.ico"))
# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
# /v1/assistants endpoints
app.include_router(setup_openai_assistant_router(server, interface), prefix=OPENAI_API_PREFIX)
# # Middleware to handle API routes first
# @app.middleware("http")
# async def handle_api_routes(request: Request, call_next):
# if request.url.path.startswith(("/v1/", "/openai/")):
# response = await call_next(request)
# if response.status_code != 404:
# return response
# return await serve_spa(request.url.path)
# # Catch-all route for SPA
# async def serve_spa(full_path: str):
# return FileResponse(os.path.join(static_files_path, "index.html"))
# /v1/chat/completions endpoints
app.include_router(setup_openai_chat_completions_router(server, interface, password), prefix=OPENAI_API_PREFIX)
# / static files
mount_static_files(app)

View File

@@ -1,262 +0,0 @@
import os
import tempfile
from functools import partial
from typing import List
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Depends,
HTTPException,
Query,
UploadFile,
)
from memgpt.schemas.document import Document
from memgpt.schemas.job import Job
from memgpt.schemas.passage import Passage
# schemas
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate, UploadFile
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
"""
Implement the following functions:
* List all available sources
* Create a new source
* Delete a source
* Upload a file to a server that is loaded into a specific source
* Paginated get all passages from a source
* Paginated get all documents from a source
* Attach a source to an agent
"""
# class ListSourcesResponse(BaseModel):
# sources: List[SourceModel] = Field(..., description="List of available sources.")
#
#
# class CreateSourceRequest(BaseModel):
# name: str = Field(..., description="The name of the source.")
# description: Optional[str] = Field(None, description="The description of the source.")
#
#
# class UploadFileToSourceRequest(BaseModel):
# file: UploadFile = Field(..., description="The file to upload.")
#
#
# class UploadFileToSourceResponse(BaseModel):
# source: SourceModel = Field(..., description="The source the file was uploaded to.")
# added_passages: int = Field(..., description="The number of passages added to the source.")
# added_documents: int = Field(..., description="The number of documents added to the source.")
#
#
# class GetSourcePassagesResponse(BaseModel):
# passages: List[PassageModel] = Field(..., description="List of passages from the source.")
#
#
# class GetSourceDocumentsResponse(BaseModel):
# documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
    # Write the uploaded bytes to a temporary directory (deleted after the
    # context manager exits), then hand the path to the server-side loader.
    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): file.filename can be None for anonymous uploads, in
        # which case os.path.join raises — confirm upstream guarantees a name.
        file_path = os.path.join(tmpdirname, file.filename)
        with open(file_path, "wb") as buffer:
            buffer.write(bytes)
        server.load_file_to_source(source_id, file_path, job_id)
def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register all /sources endpoints on the module-level router and return it.

    Each handler resolves the requesting user via `get_current_user`,
    pre-bound to this server and password, and delegates to `server`.
    """
    # Bind (server, password) so FastAPI's Depends only supplies the request.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/sources/{source_id}", tags=["sources"], response_model=Source)
    async def get_source(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a single data source by ID.
        """
        interface.clear()
        source = server.get_source(source_id=source_id, user_id=user_id)
        return source

    @router.get("/sources/name/{source_name}", tags=["sources"], response_model=str)
    async def get_source_id_by_name(
        source_name: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Resolve a source name to its ID.
        """
        interface.clear()
        source = server.get_source_id(source_name=source_name, user_id=user_id)
        return source

    @router.get("/sources", tags=["sources"], response_model=List[Source])
    async def list_sources(
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        List all data sources created by a user.
        """
        # Clear the interface
        interface.clear()
        try:
            sources = server.list_all_sources(user_id=user_id)
            return sources
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources", tags=["sources"], response_model=Source)
    async def create_source(
        request: SourceCreate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Create a new data source.
        """
        interface.clear()
        try:
            return server.create_source(request=request, user_id=user_id)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources/{source_id}", tags=["sources"], response_model=Source)
    async def update_source(
        source_id: str,
        request: SourceUpdate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Update the name or documentation of an existing data source.
        """
        interface.clear()
        try:
            return server.update_source(request=request, user_id=user_id)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.delete("/sources/{source_id}", tags=["sources"])
    async def delete_source(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Delete a data source.
        """
        interface.clear()
        try:
            server.delete_source(source_id=source_id, user_id=user_id)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=Source)
    async def attach_source_to_agent(
        source_id: str,
        agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Attach a data source to an existing agent.
        """
        interface.clear()
        assert isinstance(agent_id, str), f"Expected agent_id to be a UUID, got {agent_id}"
        assert isinstance(user_id, str), f"Expected user_id to be a UUID, got {user_id}"
        # NOTE(review): no None-check here — a missing source raises
        # AttributeError (500) on `source.id` below; a 404 would be cleaner.
        source = server.ms.get_source(source_id=source_id, user_id=user_id)
        source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=user_id)
        return source

    @router.post("/sources/{source_id}/detach", tags=["sources"])
    async def detach_source_from_agent(
        source_id: str,
        agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Detach a data source from an existing agent.
        """
        server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)

    @router.get("/sources/status/{job_id}", tags=["sources"], response_model=Job)
    async def get_job(
        job_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get the status of a job.
        """
        job = server.get_job(job_id=job_id)
        if job is None:
            raise HTTPException(status_code=404, detail=f"Job with id={job_id} not found.")
        return job

    @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=Job)
    async def upload_file_to_source(
        # file: UploadFile = UploadFile(..., description="The file to upload."),
        file: UploadFile,
        source_id: str,
        background_tasks: BackgroundTasks,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Upload a file to a data source.
        """
        interface.clear()
        source = server.ms.get_source(source_id=source_id, user_id=user_id)
        # Read now: the spooled upload may be closed before the task runs.
        bytes = file.file.read()

        # create job
        # TODO: create server function
        job = Job(user_id=user_id, metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id})
        job_id = job.id
        server.ms.create_job(job)

        # create background task
        background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes)

        # return job information
        job = server.ms.get_job(job_id=job_id)
        return job

    @router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=List[Passage])
    async def list_passages(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        List all passages associated with a data source.
        """
        # TODO: check if paginated?
        # NOTE(review): the route path ends with a trailing space
        # ("/passages ") — almost certainly a typo, preserved here as-is.
        passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
        return passages

    @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=List[Document])
    async def list_documents(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        List all documents associated with a data source.
        """
        documents = server.list_data_source_documents(user_id=user_id, source_id=source_id)
        return documents

    return router

View File

@@ -21,7 +21,8 @@ def mount_static_files(app: FastAPI):
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
if os.path.exists(static_files_path):
app.mount(
"/",
# "/",
"/app",
SPAStaticFiles(
directory=static_files_path,
html=True,

View File

@@ -1,98 +0,0 @@
from functools import partial
from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register all /tools endpoints on the module-level router and return it."""
    # Bind (server, password) so FastAPI's Depends only supplies the request.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.delete("/tools/{tool_id}", tags=["tools"])
    async def delete_tool(
        tool_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Delete a tool by ID.
        """
        # Clear the interface
        interface.clear()
        server.delete_tool(tool_id)

    @router.get("/tools/{tool_id}", tags=["tools"], response_model=Tool)
    async def get_tool(
        tool_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a tool by ID.
        """
        # Clear the interface
        interface.clear()
        tool = server.get_tool(tool_id)
        if tool is None:
            # return 404 error
            raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
        return tool

    @router.get("/tools/name/{tool_name}", tags=["tools"], response_model=str)
    async def get_tool_id(
        tool_name: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Look up a tool's ID by its name.
        """
        # Clear the interface
        interface.clear()
        tool = server.get_tool_id(tool_name, user_id=user_id)
        if tool is None:
            # return 404 error
            raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
        return tool

    @router.get("/tools", tags=["tools"], response_model=List[Tool])
    async def list_all_tools(
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a list of all tools available to agents created by a user
        """
        # Clear the interface
        interface.clear()
        return server.list_tools(user_id)

    @router.post("/tools", tags=["tools"], response_model=Tool)
    async def create_tool(
        request: ToolCreate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Create a new tool
        """
        return server.create_tool(request, user_id=user_id)

    @router.post("/tools/{tool_id}", tags=["tools"], response_model=Tool)
    async def update_tool(
        tool_id: str,
        request: ToolUpdate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Update an existing tool
        """
        try:
            # TODO: check that the user has access to this tool?
            return server.update_tool(request)
        except Exception as e:
            print(e)
            raise HTTPException(status_code=500, detail=f"Failed to update tool: {e}")

    return router

View File

@@ -5,6 +5,9 @@ from typing import AsyncGenerator, Union
from pydantic import BaseModel
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.server import SyncServer
# from memgpt.orm.user import User
# from memgpt.orm.utilities import get_db_session
@@ -51,3 +54,13 @@ async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
if finish_message:
# Signal that the stream is complete
yield sse_formatter(SSE_FINISH_MSG)
# TODO: why does this double up the interface?
def get_memgpt_server() -> SyncServer:
    """FastAPI dependency: build a SyncServer wired to a fresh streaming interface.

    NOTE(review): this constructs a brand-new SyncServer on every call —
    presumably intentional for per-request isolation, but confirm it is not
    meant to be a shared singleton (see the TODO above).
    """
    server = SyncServer(default_interface_factory=lambda: StreamingServerInterface())
    return server
def get_current_interface() -> StreamingServerInterface:
    """FastAPI dependency returning the server interface.

    NOTE(review): this returns the StreamingServerInterface *class*, not an
    instance, despite the annotation — it matches the factory-style usage
    elsewhere in this codebase, but confirm callers expect a class.
    """
    return StreamingServerInterface

View File

@@ -889,10 +889,10 @@ class SyncServer(Server):
self,
user_id: Optional[str] = None,
label: Optional[str] = None,
template: Optional[bool] = None,
template: bool = True,
name: Optional[str] = None,
id: Optional[str] = None,
):
) -> Optional[List[Block]]:
return self.ms.get_blocks(user_id=user_id, label=label, template=template, name=name, id=id)
@@ -1550,14 +1550,14 @@ class SyncServer(Server):
return sources_with_metadata
def get_tool(self, tool_id: str) -> Tool:
def get_tool(self, tool_id: str) -> Optional[Tool]:
"""Get tool by ID."""
return self.ms.get_tool(tool_id=tool_id)
def get_tool_id(self, name: str, user_id: str) -> str:
def get_tool_id(self, name: str, user_id: str) -> Optional[str]:
"""Get tool ID from name and user_id."""
tool = self.ms.get_tool(tool_name=name, user_id=user_id)
if not tool:
if not tool or tool.id is None:
return None
return tool.id
@@ -1770,3 +1770,38 @@ class SyncServer(Server):
# Get the current message
memgpt_agent = self._get_or_load_agent(agent_id=agent_id)
return memgpt_agent.retry_message()
# TODO(ethan) wire back to real method in future ORM PR
def get_current_user(self) -> User:
    """Returns the currently authed user.

    Since server is the core gateway this needs to pass through server as the
    first touchpoint.

    NOTE(review): there is no real authentication yet — this resolves the
    anonymous default user from MemGPTConfig, creating one on first use.
    """
    # NOTE: same code as local client to get the default user
    config = MemGPTConfig.load()
    user_id = config.anon_clientid
    user = self.get_user(user_id)
    if not user:
        user = self.create_user(UserCreate())
        # # update config
        # Persist the new user's id so later loads resolve the same user.
        config.anon_clientid = str(user.id)
        config.save()
    return user
def list_models(self) -> List[LLMConfig]:
    """Return the LLM configurations available on this server.

    TODO support multiple models — currently only the single server config.
    """
    return [self.server_llm_config]
def list_embedding_models(self) -> List[EmbeddingConfig]:
    """Return the embedding-model configurations available on this server.

    TODO support multiple models — currently only the single server config.
    """
    return [self.server_embedding_config]

View File

@@ -2,12 +2,14 @@ import os
import threading
import time
import uuid
from typing import Union
import pytest
from dotenv import load_dotenv
from memgpt import Admin, create_client
from memgpt.agent import Agent
from memgpt.client.client import LocalClient, RESTClient
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.memory import ChatMemory
@@ -86,7 +88,7 @@ def agent(client):
client.delete_agent(agent_state.id)
def test_create_tool(client):
def test_create_tool(client: Union[LocalClient, RESTClient]):
"""Test creation of a simple tool"""
def print_tool(message: str):
@@ -111,7 +113,9 @@ def test_create_tool(client):
print(f"Updated tools {[t.name for t in tools]}")
# check tool id
tool = client.get_tool(tool.name)
tool = client.get_tool(tool.id)
assert tool is not None, "Expected tool to be created"
assert tool.id == tool.id, f"Expected {tool.id} to be {tool.id}"
# TODO: add back once we fix admin client tool creation