feat: add V1 route refactor from integration branch into separate PR (#1734)
This commit is contained in:
14
.github/workflows/black_format.yml
vendored
14
.github/workflows/black_format.yml
vendored
@@ -1,38 +1,36 @@
|
||||
name: Code Formatter (Black)
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
black-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
with:
|
||||
ref: ${{ github.head_ref }} # Checkout the PR branch
|
||||
fetch-depth: 0 # Fetch all history for all branches and tags
|
||||
- name: "Setup Python, Poetry and Dependencies"
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev" # TODO: change this to --group dev when PR #842 lands
|
||||
|
||||
- name: Run Black
|
||||
id: black
|
||||
run: poetry run black --check .
|
||||
continue-on-error: true
|
||||
|
||||
- name: Auto-fix with Black and commit
|
||||
if: steps.black.outcome == 'failure' || ${{ !contains(github.event.issue.labels.*.name, 'check-only') }}
|
||||
if: steps.black.outcome == 'failure' && !contains(github.event.pull_request.labels.*.name, 'check-only')
|
||||
run: |
|
||||
poetry run black .
|
||||
git config --local user.email "action@github.com"
|
||||
git config --local user.name "GitHub Action"
|
||||
git commit -am "Apply Black formatting" || echo "No changes to commit"
|
||||
git diff --quiet && git diff --staged --quiet || (git add -A && git commit -m "Apply Black formatting")
|
||||
git push
|
||||
- name: Error if 'check-only' label is present
|
||||
if: steps.black.outcome == 'failure' && contains(github.event.issue.labels.*.name, 'check-only')
|
||||
if: steps.black.outcome == 'failure' && contains(github.event.pull_request.labels.*.name, 'check-only')
|
||||
run: echo "Black formatting check failed. Please run 'black .' locally to fix formatting issues." && exit 1
|
||||
|
||||
33
.github/workflows/isort_format.yml
vendored
33
.github/workflows/isort_format.yml
vendored
@@ -1,39 +1,48 @@
|
||||
name: Code Formatter (isort)
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
isort-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
with:
|
||||
ref: ${{ github.head_ref }} # Checkout the PR branch
|
||||
fetch-depth: 0 # Fetch all history for all branches and tags
|
||||
- name: "Setup Python, Poetry and Dependencies"
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev" # TODO: change this to --group dev when PR #842 lands
|
||||
|
||||
- name: Run isort
|
||||
id: isort
|
||||
run: poetry run isort --profile black --check-only .
|
||||
run: |
|
||||
output=$(poetry run isort --profile black --check-only --diff . | grep -v "Skipped" || true)
|
||||
echo "$output"
|
||||
if [ -n "$output" ]; then
|
||||
echo "isort_changed=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "isort_changed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
continue-on-error: true
|
||||
|
||||
- name: Auto-fix with isort and commit
|
||||
if: steps.isort.outcome == 'failure' || ${{ !contains(github.event.issue.labels.*.name, 'check-only') }}
|
||||
if: steps.isort.outputs.isort_changed == 'true' && !contains(github.event.pull_request.labels.*.name, 'check-only')
|
||||
run: |
|
||||
poetry run isort --profile black .
|
||||
git config --local user.email "action@github.com"
|
||||
git config --local user.name "GitHub Action"
|
||||
git commit -am "Apply isort import ordering" || echo "No changes to commit"
|
||||
git push
|
||||
- name: Error if 'check-only' label is present
|
||||
if: steps.isort.outcome == 'failure' && contains(github.event.issue.labels.*.name, 'check-only')
|
||||
if [[ -n $(git status -s) ]]; then
|
||||
git add -A
|
||||
git commit -m "Apply isort import ordering"
|
||||
git push
|
||||
else
|
||||
echo "No changes to commit"
|
||||
fi
|
||||
- name: Error if 'check-only' label is present and changes are needed
|
||||
if: steps.isort.outputs.isort_changed == 'true' && contains(github.event.pull_request.labels.*.name, 'check-only')
|
||||
run: echo "Isort check failed. Please run 'isort .' locally to fix import orders." && exit 1
|
||||
|
||||
|
||||
@@ -16,8 +16,14 @@ class Admin:
|
||||
- Generating user keys
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, token: str):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
token: str,
|
||||
api_prefix: str = "v1",
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.api_prefix = api_prefix
|
||||
self.token = token
|
||||
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
|
||||
|
||||
@@ -27,35 +33,36 @@ class Admin:
|
||||
params["cursor"] = str(cursor)
|
||||
if limit:
|
||||
params["limit"] = limit
|
||||
response = requests.get(f"{self.base_url}/admin/users", params=params, headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/users", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return [User(**user) for user in response.json()]
|
||||
|
||||
def create_key(self, user_id: str, key_name: Optional[str] = None) -> APIKey:
|
||||
request = APIKeyCreate(user_id=user_id, name=key_name)
|
||||
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=request.model_dump())
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users/keys", headers=self.headers, json=request.model_dump())
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return APIKey(**response.json())
|
||||
|
||||
def get_keys(self, user_id: str) -> List[APIKey]:
|
||||
params = {"user_id": str(user_id)}
|
||||
response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/users/keys", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return [APIKey(**key) for key in response.json()]
|
||||
|
||||
def delete_key(self, api_key: str) -> APIKey:
|
||||
params = {"api_key": api_key}
|
||||
response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/users/keys", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return APIKey(**response.json())
|
||||
|
||||
def create_user(self, name: Optional[str] = None) -> User:
|
||||
request = UserCreate(name=name)
|
||||
response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=request.model_dump())
|
||||
print("YYYYY Pinging", f"{self.base_url}/{self.api_prefix}/admin/users")
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump())
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
response_json = response.json()
|
||||
@@ -63,7 +70,7 @@ class Admin:
|
||||
|
||||
def delete_user(self, user_id: str) -> User:
|
||||
params = {"user_id": str(user_id)}
|
||||
response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers)
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/users", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return User(**response.json())
|
||||
@@ -114,23 +121,23 @@ class Admin:
|
||||
CreateToolRequest(**data) # validate
|
||||
|
||||
# make REST request
|
||||
response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/tools", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
return ToolModel(**response.json())
|
||||
|
||||
def list_tools(self):
|
||||
response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/tools", headers=self.headers)
|
||||
return ListToolsResponse(**response.json()).tools
|
||||
|
||||
def delete_tool(self, name: str):
|
||||
response = requests.delete(f"{self.base_url}/admin/tools/{name}", headers=self.headers)
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/tools/{name}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to delete tool: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def get_tool(self, name: str):
|
||||
response = requests.get(f"{self.base_url}/admin/tools/{name}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/tools/{name}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Generator, List, Optional, Union
|
||||
from typing import Callable, Dict, Generator, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
@@ -185,7 +185,7 @@ class AbstractClient(object):
|
||||
self,
|
||||
id: str,
|
||||
name: Optional[str] = None,
|
||||
func: Optional[callable] = None,
|
||||
func: Optional[Callable] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Tool:
|
||||
raise NotImplementedError
|
||||
@@ -271,6 +271,7 @@ class RESTClient(AbstractClient):
|
||||
self,
|
||||
base_url: str,
|
||||
token: str,
|
||||
api_prefix: str = "v1",
|
||||
debug: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -283,10 +284,11 @@ class RESTClient(AbstractClient):
|
||||
"""
|
||||
super().__init__(debug=debug)
|
||||
self.base_url = base_url
|
||||
self.api_prefix = api_prefix
|
||||
self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
|
||||
|
||||
def list_agents(self) -> List[AgentState]:
|
||||
response = requests.get(f"{self.base_url}/api/agents", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers)
|
||||
return [AgentState(**agent) for agent in response.json()]
|
||||
|
||||
def agent_exists(self, agent_id: str) -> bool:
|
||||
@@ -301,7 +303,7 @@ class RESTClient(AbstractClient):
|
||||
exists (bool): `True` if the agent exists, `False` otherwise
|
||||
"""
|
||||
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
# not found error
|
||||
return False
|
||||
@@ -375,7 +377,7 @@ class RESTClient(AbstractClient):
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
|
||||
response = requests.post(f"{self.base_url}/api/agents", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
@@ -399,7 +401,7 @@ class RESTClient(AbstractClient):
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/api/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update message: {response.text}")
|
||||
@@ -448,7 +450,7 @@ class RESTClient(AbstractClient):
|
||||
message_ids=message_ids,
|
||||
memory=memory,
|
||||
)
|
||||
response = requests.post(f"{self.base_url}/api/agents/{agent_id}", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
@@ -471,7 +473,7 @@ class RESTClient(AbstractClient):
|
||||
Args:
|
||||
agent_id (str): ID of the agent to delete
|
||||
"""
|
||||
response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers)
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/agents/{str(agent_id)}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to delete agent: {response.text}"
|
||||
|
||||
def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> AgentState:
|
||||
@@ -484,7 +486,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
agent_state (AgentState): State representation of the agent
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to get agent: {response.text}"
|
||||
return AgentState(**response.json())
|
||||
|
||||
@@ -512,7 +514,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
memory (Memory): In-context memory of the agent
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get in-context memory: {response.text}")
|
||||
return Memory(**response.json())
|
||||
@@ -529,7 +531,9 @@ class RESTClient(AbstractClient):
|
||||
|
||||
"""
|
||||
memory_update_dict = {section: value}
|
||||
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/memory", json=memory_update_dict, headers=self.headers)
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory", json=memory_update_dict, headers=self.headers
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update in-context memory: {response.text}")
|
||||
return Memory(**response.json())
|
||||
@@ -545,7 +549,7 @@ class RESTClient(AbstractClient):
|
||||
summary (ArchivalMemorySummary): Summary of the archival memory
|
||||
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/archival", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/archival", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get archival memory summary: {response.text}")
|
||||
return ArchivalMemorySummary(**response.json())
|
||||
@@ -560,7 +564,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
summary (RecallMemorySummary): Summary of the recall memory
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/recall", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/recall", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get recall memory summary: {response.text}")
|
||||
return RecallMemorySummary(**response.json())
|
||||
@@ -575,7 +579,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
messages (List[Message]): List of in-context messages
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory/messages", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/messages", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get in-context messages: {response.text}")
|
||||
return [Message(**message) for message in response.json()]
|
||||
@@ -620,7 +624,7 @@ class RESTClient(AbstractClient):
|
||||
params["before"] = str(before)
|
||||
if after:
|
||||
params["after"] = str(after)
|
||||
response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/archival", params=params, headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{str(agent_id)}/archival", params=params, headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to get archival memory: {response.text}"
|
||||
return [Passage(**passage) for passage in response.json()]
|
||||
|
||||
@@ -636,7 +640,9 @@ class RESTClient(AbstractClient):
|
||||
passages (List[Passage]): List of inserted passages
|
||||
"""
|
||||
request = CreateArchivalMemory(text=memory)
|
||||
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/archival", headers=self.headers, json=request.model_dump())
|
||||
response = requests.post(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/archival", headers=self.headers, json=request.model_dump()
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to insert archival memory: {response.text}")
|
||||
return [Passage(**passage) for passage in response.json()]
|
||||
@@ -649,7 +655,7 @@ class RESTClient(AbstractClient):
|
||||
agent_id (str): ID of the agent
|
||||
memory_id (str): ID of the memory
|
||||
"""
|
||||
response = requests.delete(f"{self.base_url}/api/agents/{agent_id}/archival/{memory_id}", headers=self.headers)
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/archival/{memory_id}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to delete archival memory: {response.text}"
|
||||
|
||||
# messages (recall memory)
|
||||
@@ -671,7 +677,7 @@ class RESTClient(AbstractClient):
|
||||
"""
|
||||
|
||||
params = {"before": before, "after": after, "limit": limit, "msg_object": True}
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages", params=params, headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get messages: {response.text}")
|
||||
return [Message(**message) for message in response.json()]
|
||||
@@ -708,9 +714,11 @@ class RESTClient(AbstractClient):
|
||||
from memgpt.client.streaming import _sse_post
|
||||
|
||||
request.return_message_object = False
|
||||
return _sse_post(f"{self.base_url}/api/agents/{agent_id}/messages", request.model_dump(), self.headers)
|
||||
return _sse_post(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", request.model_dump(), self.headers)
|
||||
else:
|
||||
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.post(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to send message: {response.text}")
|
||||
return MemGPTResponse(**response.json())
|
||||
@@ -719,7 +727,7 @@ class RESTClient(AbstractClient):
|
||||
|
||||
def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
|
||||
params = {"label": label, "templates_only": templates_only}
|
||||
response = requests.get(f"{self.base_url}/api/blocks", params=params, headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list blocks: {response.text}")
|
||||
|
||||
@@ -732,7 +740,7 @@ class RESTClient(AbstractClient):
|
||||
|
||||
def create_block(self, label: str, name: str, text: str) -> Block: #
|
||||
request = CreateBlock(label=label, name=name, value=text)
|
||||
response = requests.post(f"{self.base_url}/api/blocks", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create block: {response.text}")
|
||||
if request.label == "human":
|
||||
@@ -744,13 +752,13 @@ class RESTClient(AbstractClient):
|
||||
|
||||
def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block:
|
||||
request = UpdateBlock(id=block_id, name=name, value=text)
|
||||
response = requests.post(f"{self.base_url}/api/blocks/{block_id}", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update block: {response.text}")
|
||||
return Block(**response.json())
|
||||
|
||||
def get_block(self, block_id: str) -> Block:
|
||||
response = requests.get(f"{self.base_url}/api/blocks/{block_id}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
@@ -759,7 +767,7 @@ class RESTClient(AbstractClient):
|
||||
|
||||
def get_block_id(self, name: str, label: str) -> str:
|
||||
params = {"name": name, "label": label}
|
||||
response = requests.get(f"{self.base_url}/api/blocks", params=params, headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get block ID: {response.text}")
|
||||
blocks = [Block(**block) for block in response.json()]
|
||||
@@ -770,7 +778,7 @@ class RESTClient(AbstractClient):
|
||||
return blocks[0].id
|
||||
|
||||
def delete_block(self, id: str) -> Block:
|
||||
response = requests.delete(f"{self.base_url}/api/blocks/{id}", headers=self.headers)
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/blocks/{id}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to delete block: {response.text}"
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to delete block: {response.text}")
|
||||
@@ -811,7 +819,7 @@ class RESTClient(AbstractClient):
|
||||
human (Human): Updated human block
|
||||
"""
|
||||
request = UpdateHuman(id=human_id, name=name, value=text)
|
||||
response = requests.post(f"{self.base_url}/api/blocks/{human_id}", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{human_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update human: {response.text}")
|
||||
return Human(**response.json())
|
||||
@@ -851,7 +859,7 @@ class RESTClient(AbstractClient):
|
||||
persona (Persona): Updated persona block
|
||||
"""
|
||||
request = UpdatePersona(id=persona_id, name=name, value=text)
|
||||
response = requests.post(f"{self.base_url}/api/blocks/{persona_id}", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{persona_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update persona: {response.text}")
|
||||
return Persona(**response.json())
|
||||
@@ -934,7 +942,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
source (Source): Source
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/sources/{source_id}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get source: {response.text}")
|
||||
return Source(**response.json())
|
||||
@@ -949,7 +957,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
source_id (str): ID of the source
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/sources/name/{source_name}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/name/{source_name}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get source ID: {response.text}")
|
||||
return response.json()
|
||||
@@ -961,7 +969,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
sources (List[Source]): List of sources
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/sources", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/sources", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list sources: {response.text}")
|
||||
return [Source(**source) for source in response.json()]
|
||||
@@ -973,21 +981,21 @@ class RESTClient(AbstractClient):
|
||||
Args:
|
||||
source_id (str): ID of the source
|
||||
"""
|
||||
response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers)
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/sources/{str(source_id)}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to delete source: {response.text}"
|
||||
|
||||
def get_job(self, job_id: str) -> Job:
|
||||
response = requests.get(f"{self.base_url}/api/jobs/{job_id}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get job: {response.text}")
|
||||
return Job(**response.json())
|
||||
|
||||
def list_jobs(self):
|
||||
response = requests.get(f"{self.base_url}/api/jobs", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers)
|
||||
return [Job(**job) for job in response.json()]
|
||||
|
||||
def list_active_jobs(self):
|
||||
response = requests.get(f"{self.base_url}/api/jobs/active", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs/active", headers=self.headers)
|
||||
return [Job(**job) for job in response.json()]
|
||||
|
||||
def load_data(self, connector: DataConnector, source_name: str):
|
||||
@@ -1008,7 +1016,7 @@ class RESTClient(AbstractClient):
|
||||
files = {"file": open(filename, "rb")}
|
||||
|
||||
# create job
|
||||
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/upload", files=files, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to upload file to source: {response.text}")
|
||||
|
||||
@@ -1035,7 +1043,7 @@ class RESTClient(AbstractClient):
|
||||
source (Source): Created source
|
||||
"""
|
||||
payload = {"name": name}
|
||||
response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers)
|
||||
response_json = response.json()
|
||||
return Source(**response_json)
|
||||
|
||||
@@ -1049,7 +1057,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
sources (List[Source]): List of sources
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/sources", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/sources", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list attached sources: {response.text}")
|
||||
return [Source(**source) for source in response.json()]
|
||||
@@ -1066,7 +1074,7 @@ class RESTClient(AbstractClient):
|
||||
source (Source): Updated source
|
||||
"""
|
||||
request = SourceUpdate(id=source_id, name=name)
|
||||
response = requests.post(f"{self.base_url}/api/sources/{source_id}", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update source: {response.text}")
|
||||
return Source(**response.json())
|
||||
@@ -1081,13 +1089,13 @@ class RESTClient(AbstractClient):
|
||||
source_name (str): Name of the source
|
||||
"""
|
||||
params = {"agent_id": agent_id}
|
||||
response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/attach", params=params, headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
|
||||
|
||||
def detach_source(self, source_id: str, agent_id: str):
|
||||
"""Detach a source from an agent"""
|
||||
params = {"agent_id": str(agent_id)}
|
||||
response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/detach", params=params, headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
|
||||
|
||||
# server configuration commands
|
||||
@@ -1099,7 +1107,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
models (List[LLMConfig]): List of LLM models
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/config/llm", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list models: {response.text}")
|
||||
return [LLMConfig(**model) for model in response.json()]
|
||||
@@ -1111,7 +1119,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
models (List[EmbeddingConfig]): List of embedding models
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/config/embedding", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list embedding models: {response.text}")
|
||||
return [EmbeddingConfig(**model) for model in response.json()]
|
||||
@@ -1128,7 +1136,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
id (str): ID of the tool (`None` if not found)
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/tools/name/{tool_name}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{tool_name}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
@@ -1137,7 +1145,7 @@ class RESTClient(AbstractClient):
|
||||
|
||||
def create_tool(
|
||||
self,
|
||||
func,
|
||||
func: Callable,
|
||||
name: Optional[str] = None,
|
||||
update: Optional[bool] = True, # TODO: actually use this
|
||||
tags: Optional[List[str]] = None,
|
||||
@@ -1157,14 +1165,21 @@ class RESTClient(AbstractClient):
|
||||
|
||||
# TODO: check tool update code
|
||||
# TODO: check if tool already exists
|
||||
|
||||
# TODO: how to load modules?
|
||||
# parse source code/schema
|
||||
source_code = parse_source_code(func)
|
||||
source_type = "python"
|
||||
|
||||
# TODO: Check if tool already exists
|
||||
# if name:
|
||||
# tool_id = self.get_tool_id(tool_name=name)
|
||||
# if tool_id:
|
||||
# raise ValueError(f"Tool with name {name} (id={tool_id}) already exists")
|
||||
|
||||
# call server function
|
||||
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
|
||||
response = requests.post(f"{self.base_url}/api/tools", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
return Tool(**response.json())
|
||||
@@ -1173,7 +1188,7 @@ class RESTClient(AbstractClient):
|
||||
self,
|
||||
id: str,
|
||||
name: Optional[str] = None,
|
||||
func: Optional[callable] = None,
|
||||
func: Optional[Callable] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
@@ -1196,7 +1211,7 @@ class RESTClient(AbstractClient):
|
||||
source_type = "python"
|
||||
|
||||
request = ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name)
|
||||
response = requests.post(f"{self.base_url}/api/tools/{id}", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update tool: {response.text}")
|
||||
return Tool(**response.json())
|
||||
@@ -1235,7 +1250,7 @@ class RESTClient(AbstractClient):
|
||||
# raise ValueError(f"Failed to create tool: {e}, invalid input {data}")
|
||||
|
||||
# # make REST request
|
||||
# response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
|
||||
# response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=data, headers=self.headers)
|
||||
# if response.status_code != 200:
|
||||
# raise ValueError(f"Failed to create tool: {response.text}")
|
||||
# return ToolModel(**response.json())
|
||||
@@ -1247,7 +1262,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
tools (List[Tool]): List of tools
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list tools: {response.text}")
|
||||
return [Tool(**tool) for tool in response.json()]
|
||||
@@ -1259,11 +1274,11 @@ class RESTClient(AbstractClient):
|
||||
Args:
|
||||
id (str): ID of the tool
|
||||
"""
|
||||
response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers)
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/tools/{name}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to delete tool: {response.text}")
|
||||
|
||||
def get_tool(self, name: str):
|
||||
def get_tool(self, id: str) -> Optional[Tool]:
|
||||
"""
|
||||
Get a tool give its ID.
|
||||
|
||||
@@ -1273,13 +1288,30 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
tool (Tool): Tool
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers)
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/{id}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to get tool: {response.text}")
|
||||
return Tool(**response.json())
|
||||
|
||||
def get_tool_id(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Get a tool ID by its name.
|
||||
|
||||
Args:
|
||||
id (str): ID of the tool
|
||||
|
||||
Returns:
|
||||
tool (Tool): Tool
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{name}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to get tool: {response.text}")
|
||||
return response.json()
|
||||
|
||||
|
||||
class LocalClient(AbstractClient):
|
||||
"""
|
||||
@@ -1972,7 +2004,7 @@ class LocalClient(AbstractClient):
|
||||
tools = self.server.list_tools(user_id=self.user_id)
|
||||
return tools
|
||||
|
||||
def get_tool(self, id: str) -> Tool:
|
||||
def get_tool(self, id: str) -> Optional[Tool]:
|
||||
"""
|
||||
Get a tool give its ID.
|
||||
|
||||
@@ -2222,7 +2254,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
models (List[LLMConfig]): List of LLM models
|
||||
"""
|
||||
return [self.server.server_llm_config]
|
||||
return self.server.list_models()
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
"""
|
||||
@@ -2231,7 +2263,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
models (List[EmbeddingConfig]): List of embedding models
|
||||
"""
|
||||
return [self.server.server_embedding_config]
|
||||
return self.server.list_embedding_models()
|
||||
|
||||
def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
|
||||
"""
|
||||
|
||||
@@ -744,7 +744,7 @@ class MetadataStore:
|
||||
template: Optional[bool] = None,
|
||||
name: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
) -> List[Block]:
|
||||
) -> Optional[List[Block]]:
|
||||
"""List available blocks"""
|
||||
with self.session_maker() as session:
|
||||
query = session.query(BlockModel)
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
|
||||
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
||||
from memgpt.schemas.source import Source
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/agents", tags=["agents"], response_model=List[AgentState])
|
||||
def list_agents(
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
interface.clear()
|
||||
agents_data = server.list_agents(user_id=user_id)
|
||||
return agents_data
|
||||
|
||||
@router.post("/agents", tags=["agents"], response_model=AgentState)
|
||||
def create_agent(
|
||||
request: CreateAgent = Body(...),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
interface.clear()
|
||||
|
||||
agent_state = server.create_agent(request, user_id=user_id)
|
||||
return agent_state
|
||||
|
||||
@router.post("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
|
||||
def update_agent(
|
||||
agent_id: str,
|
||||
request: UpdateAgentState = Body(...),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""Update an exsiting agent"""
|
||||
interface.clear()
|
||||
try:
|
||||
# TODO: should id be moved out of UpdateAgentState?
|
||||
agent_state = server.update_agent(request, user_id=user_id)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return agent_state
|
||||
|
||||
@router.get("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
|
||||
def get_agent_state(
|
||||
agent_id: str = None,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Get the state of the agent.
|
||||
"""
|
||||
|
||||
interface.clear()
|
||||
if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
|
||||
# agent does not exist
|
||||
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
|
||||
|
||||
return server.get_agent_state(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
@router.delete("/agents/{agent_id}", tags=["agents"])
|
||||
def delete_agent(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
# agent_id = str(agent_id)
|
||||
|
||||
interface.clear()
|
||||
try:
|
||||
server.delete_agent(user_id=user_id, agent_id=agent_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
@router.get("/agents/{agent_id}/sources", tags=["agents"], response_model=List[Source])
|
||||
def get_agent_sources(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
"""
|
||||
interface.clear()
|
||||
return server.list_attached_sources(agent_id)
|
||||
|
||||
return router
|
||||
@@ -1,148 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from memgpt.schemas.memory import (
|
||||
ArchivalMemorySummary,
|
||||
CreateArchivalMemory,
|
||||
Memory,
|
||||
RecallMemorySummary,
|
||||
)
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.passage import Passage
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/agents/{agent_id}/memory/messages", tags=["agents"], response_model=List[Message])
|
||||
def get_agent_in_context_messages(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the messages in the context of a specific agent.
|
||||
"""
|
||||
interface.clear()
|
||||
return server.get_in_context_messages(agent_id=agent_id)
|
||||
|
||||
@router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
|
||||
def get_agent_memory(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the memory state of a specific agent.
|
||||
|
||||
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
interface.clear()
|
||||
return server.get_agent_memory(agent_id=agent_id)
|
||||
|
||||
@router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
|
||||
def update_agent_memory(
|
||||
agent_id: str,
|
||||
request: Dict = Body(...),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Update the core memory of a specific agent.
|
||||
|
||||
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
interface.clear()
|
||||
memory = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=request)
|
||||
return memory
|
||||
|
||||
@router.get("/agents/{agent_id}/memory/recall", tags=["agents"], response_model=RecallMemorySummary)
|
||||
def get_agent_recall_memory_summary(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the summary of the recall memory of a specific agent.
|
||||
"""
|
||||
interface.clear()
|
||||
return server.get_recall_memory_summary(agent_id=agent_id)
|
||||
|
||||
@router.get("/agents/{agent_id}/memory/archival", tags=["agents"], response_model=ArchivalMemorySummary)
|
||||
def get_agent_archival_memory_summary(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the summary of the archival memory of a specific agent.
|
||||
"""
|
||||
interface.clear()
|
||||
return server.get_archival_memory_summary(agent_id=agent_id)
|
||||
|
||||
# @router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=List[Passage])
|
||||
# def get_agent_archival_memory_all(
|
||||
# agent_id: str,
|
||||
# user_id: str = Depends(get_current_user_with_server),
|
||||
# ):
|
||||
# """
|
||||
# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
|
||||
# """
|
||||
# interface.clear()
|
||||
# return server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
@router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage])
|
||||
def get_agent_archival_memory(
|
||||
agent_id: str,
|
||||
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
|
||||
limit: Optional[int] = Query(None, description="How many results to include in the response."),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
"""
|
||||
interface.clear()
|
||||
return server.get_agent_archival_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@router.post("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage])
|
||||
def insert_agent_archival_memory(
|
||||
agent_id: str,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
"""
|
||||
interface.clear()
|
||||
return server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=request.text)
|
||||
|
||||
@router.delete("/agents/{agent_id}/archival/{memory_id}", tags=["agents"])
|
||||
def delete_agent_archival_memory(
|
||||
agent_id: str,
|
||||
memory_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Delete a memory from an agent's archival memory store.
|
||||
"""
|
||||
# TODO: should probably return a `Passage`
|
||||
interface.clear()
|
||||
try:
|
||||
server.delete_archival_memory(user_id=user_id, agent_id=agent_id, memory_id=memory_id)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
return router
|
||||
@@ -1,197 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from memgpt.schemas.enums import MessageRole, MessageStreamStatus
|
||||
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
|
||||
from memgpt.schemas.memgpt_request import MemGPTRequest
|
||||
from memgpt.schemas.memgpt_response import MemGPTResponse
|
||||
from memgpt.schemas.message import Message, UpdateMessage
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
|
||||
from memgpt.server.rest_api.utils import sse_async_generator
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.utils import deduplicate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# TODO: cpacker should check this file
|
||||
# TODO: move this into server.py?
|
||||
async def send_message_to_agent(
|
||||
server: SyncServer,
|
||||
agent_id: str,
|
||||
user_id: str,
|
||||
role: MessageRole,
|
||||
message: str,
|
||||
stream_steps: bool,
|
||||
stream_tokens: bool,
|
||||
return_message_object: bool, # Should be True for Python Client, False for REST API
|
||||
chat_completion_mode: Optional[bool] = False,
|
||||
timestamp: Optional[datetime] = None,
|
||||
# related to whether or not we return `MemGPTMessage`s or `Message`s
|
||||
) -> Union[StreamingResponse, MemGPTResponse]:
|
||||
"""Split off into a separate function so that it can be imported in the /chat/completion proxy."""
|
||||
# TODO: @charles is this the correct way to handle?
|
||||
include_final_message = True
|
||||
|
||||
# determine role
|
||||
if role == MessageRole.user:
|
||||
message_func = server.user_message
|
||||
elif role == MessageRole.system:
|
||||
message_func = server.system_message
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Bad role {role}")
|
||||
|
||||
if not stream_steps and stream_tokens:
|
||||
raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
|
||||
|
||||
# For streaming response
|
||||
try:
|
||||
|
||||
# TODO: move this logic into server.py
|
||||
|
||||
# Get the generator object off of the agent's streaming interface
|
||||
# This will be attached to the POST SSE request used under-the-hood
|
||||
memgpt_agent = server._get_or_load_agent(agent_id=agent_id)
|
||||
streaming_interface = memgpt_agent.interface
|
||||
if not isinstance(streaming_interface, StreamingServerInterface):
|
||||
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
|
||||
|
||||
# Enable token-streaming within the request if desired
|
||||
streaming_interface.streaming_mode = stream_tokens
|
||||
# "chatcompletion mode" does some remapping and ignores inner thoughts
|
||||
streaming_interface.streaming_chat_completion_mode = chat_completion_mode
|
||||
|
||||
# streaming_interface.allow_assistant_message = stream
|
||||
# streaming_interface.function_call_legacy_mode = stream
|
||||
|
||||
# Offload the synchronous message_func to a separate thread
|
||||
streaming_interface.stream_start()
|
||||
task = asyncio.create_task(
|
||||
asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
|
||||
)
|
||||
|
||||
if stream_steps:
|
||||
if return_message_object:
|
||||
# TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
|
||||
raise NotImplementedError
|
||||
|
||||
# return a stream
|
||||
return StreamingResponse(
|
||||
sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
else:
|
||||
# buffer the stream, then return the list
|
||||
generated_stream = []
|
||||
async for message in streaming_interface.get_generator():
|
||||
assert (
|
||||
isinstance(message, MemGPTMessage)
|
||||
or isinstance(message, LegacyMemGPTMessage)
|
||||
or isinstance(message, MessageStreamStatus)
|
||||
), type(message)
|
||||
generated_stream.append(message)
|
||||
if message == MessageStreamStatus.done:
|
||||
break
|
||||
|
||||
# Get rid of the stream status messages
|
||||
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
|
||||
usage = await task
|
||||
|
||||
# By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
|
||||
# If we want to convert these to Message, we can use the attached IDs
|
||||
# NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call)
|
||||
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
|
||||
if return_message_object:
|
||||
message_ids = [m.id for m in filtered_stream]
|
||||
message_ids = deduplicate(message_ids)
|
||||
message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
|
||||
return MemGPTResponse(messages=message_objs, usage=usage)
|
||||
else:
|
||||
return MemGPTResponse(messages=filtered_stream, usage=usage)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
|
||||
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message])
|
||||
def get_agent_messages(
|
||||
agent_id: str,
|
||||
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
|
||||
limit: int = Query(10, description="Maximum number of messages to retrieve."),
|
||||
msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
interface.clear()
|
||||
return server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
before=before,
|
||||
limit=limit,
|
||||
reverse=True,
|
||||
return_message_object=msg_object,
|
||||
)
|
||||
|
||||
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse)
|
||||
async def send_message(
|
||||
# background_tasks: BackgroundTasks,
|
||||
agent_id: str,
|
||||
request: MemGPTRequest = Body(...),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It can optionally stream the response if 'stream' is set to True.
|
||||
"""
|
||||
# TODO: should this recieve multiple messages? @cpacker
|
||||
# TODO: revise to `MemGPTRequest`
|
||||
# TODO: support sending multiple messages
|
||||
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
|
||||
message = request.messages[0]
|
||||
|
||||
# TODO: what to do with message.name?
|
||||
return await send_message_to_agent(
|
||||
server=server,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
role=message.role,
|
||||
message=message.text,
|
||||
stream_steps=request.stream_steps,
|
||||
stream_tokens=request.stream_tokens,
|
||||
return_message_object=request.return_message_object,
|
||||
)
|
||||
|
||||
@router.patch("/agents/{agent_id}/messages/{message_id}", tags=["agents"], response_model=Message)
|
||||
async def update_message(
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
request: UpdateMessage = Body(...),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}"
|
||||
return server.update_agent_message(agent_id=agent_id, request=request)
|
||||
|
||||
return router
|
||||
@@ -1,73 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
|
||||
from memgpt.schemas.block import Block, CreateBlock
|
||||
from memgpt.schemas.block import Human as HumanModel # TODO: modify
|
||||
from memgpt.schemas.block import UpdateBlock
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def setup_block_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/blocks", tags=["block"], response_model=List[Block])
|
||||
async def list_blocks(
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
# query parameters
|
||||
label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
|
||||
templates_only: bool = Query(True, description="Whether to include only templates"),
|
||||
name: Optional[str] = Query(None, description="Name of the block"),
|
||||
):
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
blocks = server.get_blocks(user_id=user_id, label=label, template=templates_only, name=name)
|
||||
if blocks is None:
|
||||
return []
|
||||
return blocks
|
||||
|
||||
@router.post("/blocks", tags=["block"], response_model=Block)
|
||||
async def create_block(
|
||||
request: CreateBlock = Body(...),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
request.user_id = user_id # TODO: remove?
|
||||
return server.create_block(user_id=user_id, request=request)
|
||||
|
||||
@router.post("/blocks/{block_id}", tags=["block"], response_model=Block)
|
||||
async def update_block(
|
||||
block_id: str,
|
||||
request: UpdateBlock = Body(...),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
# TODO: should this be in the param or the POST data?
|
||||
assert block_id == request.id
|
||||
return server.update_block(request)
|
||||
|
||||
@router.delete("/blocks/{block_id}", tags=["block"], response_model=Block)
|
||||
async def delete_block(
|
||||
block_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
return server.delete_block(block_id=block_id)
|
||||
|
||||
@router.get("/blocks/{block_id}", tags=["block"], response_model=Block)
|
||||
async def get_block(
|
||||
block_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
block = server.get_block(block_id=block_id)
|
||||
if block is None:
|
||||
raise HTTPException(status_code=404, detail="Block not found")
|
||||
return block
|
||||
|
||||
return router
|
||||
@@ -1,40 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.schemas.embedding_config import EmbeddingConfig
|
||||
from memgpt.schemas.llm_config import LLMConfig
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ConfigResponse(BaseModel):
|
||||
config: dict = Field(..., description="The server configuration object.")
|
||||
defaults: dict = Field(..., description="The defaults for the configuration.")
|
||||
|
||||
|
||||
def setup_config_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/config/llm", tags=["config"], response_model=List[LLMConfig])
|
||||
def get_llm_configs(user_id: str = Depends(get_current_user_with_server)):
|
||||
"""
|
||||
Retrieve the base configuration for the server.
|
||||
"""
|
||||
interface.clear()
|
||||
return [server.server_llm_config]
|
||||
|
||||
@router.get("/config/embedding", tags=["config"], response_model=List[EmbeddingConfig])
|
||||
def get_embedding_configs(user_id: str = Depends(get_current_user_with_server)):
|
||||
"""
|
||||
Retrieve the base configuration for the server.
|
||||
"""
|
||||
interface.clear()
|
||||
return [server.server_embedding_config]
|
||||
|
||||
return router
|
||||
@@ -1,41 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from memgpt.schemas.job import Job
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def setup_jobs_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/jobs", tags=["jobs"], response_model=List[Job])
|
||||
async def list_jobs(
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
|
||||
# TODO: add filtering by status
|
||||
return server.list_jobs(user_id=user_id)
|
||||
|
||||
@router.get("/jobs/active", tags=["jobs"], response_model=List[Job])
|
||||
async def list_active_jobs(
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
return server.list_active_jobs(user_id=user_id)
|
||||
|
||||
@router.get("/jobs/{job_id}", tags=["jobs"], response_model=Job)
|
||||
async def get_job(
|
||||
job_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
return server.get_job(job_id=job_id)
|
||||
|
||||
return router
|
||||
@@ -1,38 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.schemas.llm_config import LLMConfig
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ListModelsResponse(BaseModel):
|
||||
models: List[LLMConfig] = Field(..., description="List of model configurations.")
|
||||
|
||||
|
||||
def setup_models_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/models", tags=["models"], response_model=ListModelsResponse)
|
||||
async def list_models():
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
|
||||
# currently, the server only supports one model, however this may change in the future
|
||||
llm_config = LLMConfig(
|
||||
model=server.server_llm_config.model,
|
||||
model_endpoint=server.server_llm_config.model_endpoint,
|
||||
model_endpoint_type=server.server_llm_config.model_endpoint_type,
|
||||
model_wrapper=server.server_llm_config.model_wrapper,
|
||||
context_window=server.server_llm_config.context_window,
|
||||
)
|
||||
|
||||
return ListModelsResponse(models=[llm_config])
|
||||
|
||||
return router
|
||||
@@ -1,488 +0,0 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.openai.openai import (
|
||||
AssistantFile,
|
||||
MessageFile,
|
||||
MessageRoleType,
|
||||
OpenAIAssistant,
|
||||
OpenAIMessage,
|
||||
OpenAIRun,
|
||||
OpenAIRunStep,
|
||||
OpenAIThread,
|
||||
Text,
|
||||
ToolCall,
|
||||
ToolCallOutput,
|
||||
)
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.utils import get_utc_time
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateAssistantRequest(BaseModel):
    """Request body for POST /assistants (OpenAI-compatible assistant creation)."""

    model: str = Field(..., description="The model to use for the assistant.")
    name: str = Field(..., description="The name of the assistant.")
    # NOTE(review): the fields below default to None but are not typed
    # Optional[...] — confirm this validates under the pydantic version in use.
    description: str = Field(None, description="The description of the assistant.")
    instructions: str = Field(..., description="The instructions for the assistant.")
    tools: List[str] = Field(None, description="The tools used by the assistant.")
    file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.")
    metadata: dict = Field(None, description="Metadata associated with the assistant.")

    # memgpt-only (not openai)
    embedding_model: str = Field(None, description="The model to use for the assistant.")

    ## TODO: remove
    # user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
|
||||
|
||||
class CreateThreadRequest(BaseModel):
|
||||
messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.")
|
||||
|
||||
# memgpt-only
|
||||
assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. MemGPT preset)")
|
||||
|
||||
|
||||
class ModifyThreadRequest(BaseModel):
|
||||
metadata: dict = Field(None, description="Metadata associated with the thread.")
|
||||
|
||||
|
||||
class ModifyMessageRequest(BaseModel):
|
||||
metadata: dict = Field(None, description="Metadata associated with the message.")
|
||||
|
||||
|
||||
class ModifyRunRequest(BaseModel):
|
||||
metadata: dict = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class CreateMessageRequest(BaseModel):
|
||||
role: str = Field(..., description="Role of the message sender (either 'user' or 'system')")
|
||||
content: str = Field(..., description="The message content to be processed by the agent.")
|
||||
file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the message.")
|
||||
|
||||
|
||||
class UserMessageRequest(BaseModel):
|
||||
user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
message: str = Field(..., description="The message content to be processed by the agent.")
|
||||
stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
|
||||
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
|
||||
|
||||
|
||||
class UserMessageResponse(BaseModel):
|
||||
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
|
||||
|
||||
|
||||
class GetAgentMessagesRequest(BaseModel):
|
||||
user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
start: int = Field(..., description="Message index to start on (reverse chronological).")
|
||||
count: int = Field(..., description="How many messages to retrieve.")
|
||||
|
||||
|
||||
class ListMessagesResponse(BaseModel):
|
||||
messages: List[OpenAIMessage] = Field(..., description="List of message objects.")
|
||||
|
||||
|
||||
class CreateAssistantFileRequest(BaseModel):
|
||||
file_id: str = Field(..., description="The unique identifier of the file.")
|
||||
|
||||
|
||||
class CreateRunRequest(BaseModel):
|
||||
assistant_id: str = Field(..., description="The unique identifier of the assistant.")
|
||||
model: Optional[str] = Field(None, description="The model used by the run.")
|
||||
instructions: str = Field(..., description="The instructions for the run.")
|
||||
additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.")
|
||||
tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class CreateThreadRunRequest(BaseModel):
|
||||
assistant_id: str = Field(..., description="The unique identifier of the assistant.")
|
||||
thread: OpenAIThread = Field(..., description="The thread to run.")
|
||||
model: str = Field(..., description="The model used by the run.")
|
||||
instructions: str = Field(..., description="The instructions for the run.")
|
||||
tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class DeleteAssistantResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the agent.")
|
||||
object: str = "assistant.deleted"
|
||||
deleted: bool = Field(..., description="Whether the agent was deleted.")
|
||||
|
||||
|
||||
class DeleteAssistantFileResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the file.")
|
||||
object: str = "assistant.file.deleted"
|
||||
deleted: bool = Field(..., description="Whether the file was deleted.")
|
||||
|
||||
|
||||
class DeleteThreadResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier of the agent.")
|
||||
object: str = "thread.deleted"
|
||||
deleted: bool = Field(..., description="Whether the agent was deleted.")
|
||||
|
||||
|
||||
class SubmitToolOutputsToRunRequest(BaseModel):
|
||||
tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.")
|
||||
|
||||
|
||||
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
|
||||
def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface):
|
||||
# create assistant (MemGPT agent)
|
||||
@router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def create_assistant(request: CreateAssistantRequest = Body(...)):
|
||||
# TODO: create preset
|
||||
return OpenAIAssistant(
|
||||
id=DEFAULT_PRESET,
|
||||
name="default_preset",
|
||||
description=request.description,
|
||||
created_at=int(get_utc_time().timestamp()),
|
||||
model=request.model,
|
||||
instructions=request.instructions,
|
||||
tools=request.tools,
|
||||
file_ids=request.file_ids,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
@router.post("/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
|
||||
def create_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
request: CreateAssistantFileRequest = Body(...),
|
||||
):
|
||||
# TODO: add file to assistant
|
||||
return AssistantFile(
|
||||
id=request.file_id,
|
||||
created_at=int(get_utc_time().timestamp()),
|
||||
assistant_id=assistant_id,
|
||||
)
|
||||
|
||||
@router.get("/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
|
||||
def list_assistants(
|
||||
limit: int = Query(1000, description="How many assistants to retrieve."),
|
||||
order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
before: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
):
|
||||
# TODO: implement list assistants (i.e. list available MemGPT presets)
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
|
||||
def list_assistant_files(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
limit: int = Query(1000, description="How many files to retrieve."),
|
||||
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
before: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
):
|
||||
# TODO: list attached data sources to preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def retrieve_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
):
|
||||
# TODO: get and return preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
|
||||
def retrieve_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: return data source attached to preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def modify_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
request: CreateAssistantRequest = Body(...),
|
||||
):
|
||||
# TODO: modify preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.delete("/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
|
||||
def delete_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
):
|
||||
# TODO: delete preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.delete("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
|
||||
def delete_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: delete source on preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads", tags=["threads"], response_model=OpenAIThread)
def create_thread(request: CreateThreadRequest = Body(...)):
    """Create a new thread by provisioning a MemGPT agent and returning it as an OpenAI thread."""
    # TODO: use requests.description and requests.metadata fields
    # TODO: handle requests.file_ids and requests.tools
    # TODO: eventually allow request to override embedding/llm model

    print("Create thread/agent", request)
    # create a memgpt agent
    # NOTE(review): `user_id` is not defined anywhere in this router's scope —
    # this raises NameError at runtime; auth/user wiring is missing here.
    agent_state = server.create_agent(
        user_id=user_id,
    )
    # TODO: insert messages into recall memory
    return OpenAIThread(
        id=str(agent_state.id),
        created_at=int(agent_state.created_at.timestamp()),
    )
|
||||
|
||||
@router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
def retrieve_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
):
    """Fetch the agent backing this thread and present it as an OpenAI thread."""
    agent_uuid = uuid.UUID(thread_id)
    agent = server.get_agent(agent_uuid)
    created_ts = int(agent.created_at.timestamp())
    return OpenAIThread(id=str(agent.id), created_at=created_ts)
|
||||
|
||||
@router.post("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
def modify_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: ModifyThreadRequest = Body(...),
):
    """Modify a thread's metadata (not implemented yet).

    Fix: this was registered with @router.get, colliding with retrieve_thread's
    GET /threads/{thread_id}. A modify endpoint that takes a request body
    belongs on POST, matching modify_message and modify_run in this router.
    """
    # TODO: add agent metadata so this can be modified
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.delete("/threads/{thread_id}", tags=["threads"], response_model=DeleteThreadResponse)
|
||||
def delete_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
):
|
||||
# TODO: delete agent
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
def create_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: CreateMessageRequest = Body(...),
):
    """Append a message to the thread's agent and echo it back in OpenAI message format."""
    agent_id = uuid.UUID(thread_id)
    # create message object
    # NOTE(review): `user_id` is not defined anywhere in this router's scope —
    # this raises NameError at runtime; auth/user wiring is missing here.
    message = Message(
        user_id=user_id,
        agent_id=agent_id,
        role=request.role,
        text=request.content,
    )
    agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
    # add message to agent
    agent._append_to_messages([message])

    openai_message = OpenAIMessage(
        id=str(message.id),
        created_at=int(message.created_at.timestamp()),
        content=[Text(text=message.text)],
        role=message.role,
        thread_id=str(message.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
    return openai_message
|
||||
|
||||
@router.get("/threads/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
def list_messages(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many messages to retrieve."),
    order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
    after: str = Query(
        None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
    ),
    before: str = Query(
        None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
    ),
):
    """List the thread's messages in OpenAI message format, paged by after/before cursors.

    Fixes:
    - `after_uuid` was gated on `before` (copy/paste bug), so the `after`
      cursor was silently ignored unless `before` was also supplied.
    - Removed the debug print of `json_messages[0]["text"]`, which raised
      IndexError when the message history was empty.
    """
    after_uuid = uuid.UUID(after) if after else None
    before_uuid = uuid.UUID(before) if before else None
    agent_id = uuid.UUID(thread_id)
    reverse = order == "desc"
    # NOTE(review): `user_id` is not defined in this router's scope (NameError
    # at runtime) — auth wiring is missing here, as in the other thread routes.
    cursor, json_messages = server.get_agent_recall_cursor(
        user_id=user_id,
        agent_id=agent_id,
        limit=limit,
        after=after_uuid,
        before=before_uuid,
        order_by="created_at",
        reverse=reverse,
    )
    # Convert the raw recall-memory records to OpenAI-style message objects.
    openai_messages = [
        OpenAIMessage(
            id=str(message["id"]),
            created_at=int(message["created_at"].timestamp()),
            content=[Text(text=message["text"])],
            role=message["role"],
            thread_id=str(message["agent_id"]),
            assistant_id=DEFAULT_PRESET,  # TODO: update this
            # file_ids=message.file_ids,
            # metadata=message.metadata,
        )
        for message in json_messages
    ]
    # TODO: cast back to message objects
    return ListMessagesResponse(messages=openai_messages)
|
||||
|
||||
@router.get("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def retrieve_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
):
    """Fetch a single message from a thread in OpenAI message format.

    Fix: the route registration was written as a bare ``router.get(...)`` call
    (missing the ``@``), so the decorator result was discarded and this handler
    was never actually registered on the router.
    """
    message_uuid = uuid.UUID(message_id)
    agent_id = uuid.UUID(thread_id)
    message = server.get_agent_message(agent_id, message_uuid)
    return OpenAIMessage(
        id=str(message.id),
        created_at=int(message.created_at.timestamp()),
        content=[Text(text=message.text)],
        role=message.role,
        thread_id=str(message.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
|
||||
|
||||
@router.get("/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
|
||||
def retrieve_message_file(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: implement?
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
|
||||
def modify_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
request: ModifyMessageRequest = Body(...),
|
||||
):
|
||||
# TODO: add metada field to message so this can be modified
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
|
||||
def create_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: CreateRunRequest = Body(...),
|
||||
):
|
||||
# TODO: add request.instructions as a message?
|
||||
agent_id = uuid.UUID(thread_id)
|
||||
# TODO: override preset of agent with request.assistant_id
|
||||
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
agent.step(user_message=None) # already has messages added
|
||||
run_id = str(uuid.uuid4())
|
||||
create_time = int(get_utc_time().timestamp())
|
||||
return OpenAIRun(
|
||||
id=run_id,
|
||||
created_at=create_time,
|
||||
thread_id=str(agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
status="completed", # TODO: eventaully allow offline execution
|
||||
expires_at=create_time,
|
||||
model=agent.agent_state.llm_config.model,
|
||||
instructions=request.instructions,
|
||||
)
|
||||
|
||||
@router.post("/threads/runs", tags=["runs"], response_model=OpenAIRun)
|
||||
def create_thread_and_run(
|
||||
request: CreateThreadRunRequest = Body(...),
|
||||
):
|
||||
# TODO: add a bunch of messages and execute
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/threads/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
|
||||
def list_runs(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
limit: int = Query(1000, description="How many runs to retrieve."),
|
||||
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
before: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
):
|
||||
# TODO: store run information in a DB so it can be returned here
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
|
||||
def list_run_steps(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
limit: int = Query(1000, description="How many run steps to retrieve."),
|
||||
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
before: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
):
|
||||
# TODO: store run information in a DB so it can be returned here
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
|
||||
def retrieve_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
|
||||
def retrieve_run_step(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
step_id: str = Path(..., description="The unique identifier of the run step."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
|
||||
def modify_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
request: ModifyRunRequest = Body(...),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
|
||||
def submit_tool_outputs_to_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
request: SubmitToolOutputsToRunRequest = Body(...),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
|
||||
def cancel_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
return router
|
||||
@@ -1,129 +0,0 @@
|
||||
import json
|
||||
import uuid
|
||||
from functools import partial
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
|
||||
# from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from memgpt.schemas.openai.chat_completion_response import (
|
||||
ChatCompletionResponse,
|
||||
Choice,
|
||||
Message,
|
||||
UsageStatistics,
|
||||
)
|
||||
from memgpt.server.rest_api.agents.message import send_message_to_agent
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.utils import get_utc_time
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def setup_openai_chat_completions_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.post("/chat/completions", tags=["chat_completions"], response_model=ChatCompletionResponse)
async def create_chat_completion(
    request: ChatCompletionRequest = Body(...),
    user_id: uuid.UUID = Depends(get_current_user_with_server),
):
    """Send a message to a MemGPT agent via a /chat/completions request

    The bearer token will be used to identify the user.
    The 'user' field in the request should be set to the agent ID.

    Fixes:
    - Narrowed two bare ``except:`` clauses to the specific exceptions.
    - ``len(messages) > 1`` let an empty list slip through and crash on
      ``messages[0]``; changed to ``!= 1``.
    """
    agent_id = request.user
    if agent_id is None:
        raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
    try:
        agent_id = uuid.UUID(agent_id)
    except (ValueError, TypeError, AttributeError):
        # These are what uuid.UUID raises on malformed input; anything else
        # should propagate instead of being masked as a 400.
        raise HTTPException(status_code=400, detail="agent_id (in the 'user' field) must be a valid UUID")

    messages = request.messages
    if messages is None:
        raise HTTPException(status_code=400, detail="'messages' field must not be empty")
    if len(messages) != 1:
        raise HTTPException(status_code=400, detail="'messages' field must be a list of length 1")
    if messages[0].role != "user":
        raise HTTPException(status_code=400, detail="'messages[0].role' must be a 'user'")

    input_message = request.messages[0]
    if request.stream:
        print("Starting streaming OpenAI proxy response")

        return await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=user_id,
            role=input_message.role,
            message=str(input_message.content),
            stream_legacy=False,
            # Turn streaming ON
            stream_steps=True,
            stream_tokens=True,
            # Turn on ChatCompletion mode (eg remaps send_message to content)
            chat_completion_mode=True,
        )

    else:
        print("Starting non-streaming OpenAI proxy response")

        response_messages = await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=user_id,
            role=input_message.role,
            message=str(input_message.content),
            stream_legacy=False,
            # Turn streaming OFF
            stream_steps=False,
            stream_tokens=False,
        )

        # Concatenate all send_message outputs together
        id = ""
        visible_message_str = ""
        created_at = None
        for memgpt_msg in response_messages.messages:
            if "function_call" in memgpt_msg:
                memgpt_function_call = memgpt_msg["function_call"]
                if "name" in memgpt_function_call and memgpt_function_call["name"] == "send_message":
                    try:
                        memgpt_function_call_args = json.loads(memgpt_function_call["arguments"])
                        visible_message_str += memgpt_function_call_args["message"]
                        id = memgpt_function_call["id"]
                        created_at = memgpt_msg["date"]
                    except (KeyError, ValueError):
                        # json.JSONDecodeError is a ValueError subclass; missing
                        # keys raise KeyError. Best-effort: skip unparseable msgs.
                        print(f"Failed to parse MemGPT message: {str(memgpt_function_call)}")
                else:
                    print(f"Skipping function_call: {str(memgpt_function_call)}")
            else:
                print(f"Skipping message: {str(memgpt_msg)}")

        response = ChatCompletionResponse(
            id=id,
            created=created_at if created_at else get_utc_time(),
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=Message(
                        role="assistant",
                        content=visible_message_str,
                    ),
                )
            ],
            # TODO add real usage
            usage=UsageStatistics(
                completion_tokens=0,
                prompt_tokens=0,
                total_tokens=0,
            ),
        )
        return response
|
||||
|
||||
return router
|
||||
@@ -1,70 +0,0 @@
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.schemas.block import Persona as PersonaModel # TODO: modify
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ListPersonasResponse(BaseModel):
    """Response payload for GET /personas: all personas belonging to the user."""

    personas: List[PersonaModel] = Field(..., description="List of persona configurations.")
|
||||
|
||||
|
||||
class CreatePersonaRequest(BaseModel):
    """Request body for POST /personas."""

    text: str = Field(..., description="The persona text.")
    name: str = Field(..., description="The name of the persona.")
|
||||
|
||||
|
||||
def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/personas", tags=["personas"], response_model=ListPersonasResponse)
|
||||
async def list_personas(
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
|
||||
personas = server.ms.list_personas(user_id=user_id)
|
||||
return ListPersonasResponse(personas=personas)
|
||||
|
||||
@router.post("/personas", tags=["personas"], response_model=PersonaModel)
async def create_persona(
    request: CreatePersonaRequest = Body(...),
    user_id: uuid.UUID = Depends(get_current_user_with_server),
):
    """Create and persist a new persona for the authenticated user.

    Fix: return the object that was actually persisted instead of constructing
    a second, equivalent PersonaModel — the duplicate construction could drift
    from the stored record (e.g. any auto-generated fields beyond `id`).
    """
    # TODO: disallow duplicate names for personas
    interface.clear()
    new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
    server.ms.create_persona(new_persona)
    return new_persona
|
||||
|
||||
@router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
async def delete_persona(
    persona_name: str,
    user_id: uuid.UUID = Depends(get_current_user_with_server),
):
    """Delete one of the user's personas by name and return the deleted record.

    NOTE(review): unlike get_persona there is no None check here — if
    server.ms.delete_persona returns None for a missing persona, the
    response_model validation would fail; confirm and consider raising 404.
    """
    interface.clear()
    persona = server.ms.delete_persona(persona_name, user_id=user_id)
    return persona
|
||||
|
||||
@router.get("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
async def get_persona(
    persona_name: str,
    user_id: uuid.UUID = Depends(get_current_user_with_server),
):
    """Look up one of the user's personas by name; 404 if it does not exist."""
    interface.clear()
    found = server.ms.get_persona(persona_name, user_id=user_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Persona not found")
    return found
|
||||
|
||||
return router
|
||||
115
memgpt/server/rest_api/routers/openai/assistants/assistants.py
Normal file
115
memgpt/server/rest_api/routers/openai/assistants/assistants.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import List

from fastapi import APIRouter, Body, HTTPException, Path, Query

from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.openai.openai import AssistantFile, OpenAIAssistant
from memgpt.server.rest_api.routers.openai.assistants.schemas import (
    CreateAssistantFileRequest,
    CreateAssistantRequest,
    DeleteAssistantFileResponse,
    DeleteAssistantResponse,
)
from memgpt.utils import get_utc_time

# BUG FIX: a bare `router = APIRouter()` was assigned here and then immediately
# shadowed by the prefixed router below; the dead assignment has been removed.
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
router = APIRouter(prefix="/v1/assistants", tags=["assistants"])
|
||||
|
||||
|
||||
# create assistant (MemGPT agent)
@router.post("/", response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
    """Create an assistant (stubbed: echoes the request back against the default preset)."""
    # TODO: create preset
    assistant = OpenAIAssistant(
        id=DEFAULT_PRESET,
        name="default_preset",
        description=request.description,
        created_at=int(get_utc_time().timestamp()),
        model=request.model,
        instructions=request.instructions,
        tools=request.tools,
        file_ids=request.file_ids,
        metadata=request.metadata,
    )
    return assistant
|
||||
|
||||
|
||||
@router.post("/{assistant_id}/files", response_model=AssistantFile)
def create_assistant_file(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    request: CreateAssistantFileRequest = Body(...),
):
    """Attach a file to an assistant (stubbed: echoes the association back)."""
    # TODO: add file to assistant
    attachment = AssistantFile(
        id=request.file_id,
        created_at=int(get_utc_time().timestamp()),
        assistant_id=assistant_id,
    )
    return attachment
|
||||
|
||||
|
||||
@router.get("/", response_model=List[OpenAIAssistant])
def list_assistants(
    limit: int = Query(1000, description="How many assistants to retrieve."),
    order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    # BUG FIX: the `before` description was copy-pasted from `after`.
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List available assistants (not yet implemented)."""
    # TODO: implement list assistants (i.e. list available MemGPT presets)
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{assistant_id}/files", response_model=List[AssistantFile])
def list_assistant_files(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    limit: int = Query(1000, description="How many files to retrieve."),
    order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    # BUG FIX: the `before` description was copy-pasted from `after`.
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List files attached to an assistant (not yet implemented)."""
    # TODO: list attached data sources to preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{assistant_id}", response_model=OpenAIAssistant)
def retrieve_assistant(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
    """Fetch a single assistant by id (not yet implemented)."""
    # TODO: get and return preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{assistant_id}/files/{file_id}", response_model=AssistantFile)
def retrieve_assistant_file(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    """Fetch a single assistant file (not yet implemented)."""
    # TODO: return data source attached to preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{assistant_id}", response_model=OpenAIAssistant)
def modify_assistant(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    request: CreateAssistantRequest = Body(...),
):
    """Modify an assistant in place (not yet implemented)."""
    # TODO: modify preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.delete("/{assistant_id}", response_model=DeleteAssistantResponse)
def delete_assistant(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
    """Delete an assistant (not yet implemented)."""
    # TODO: delete preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.delete("/{assistant_id}/files/{file_id}", response_model=DeleteAssistantFileResponse)
def delete_assistant_file(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    """Detach a file from an assistant (not yet implemented)."""
    # TODO: delete source on preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
121
memgpt/server/rest_api/routers/openai/assistants/schemas.py
Normal file
121
memgpt/server/rest_api/routers/openai/assistants/schemas.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from typing import List, Optional

from pydantic import BaseModel, Field

# Shared OpenAI-compatible types used by the request/response models below.
from memgpt.schemas.openai.openai import (
    MessageRoleType,
    OpenAIMessage,
    OpenAIThread,
    ToolCall,
    ToolCallOutput,
)
|
||||
|
||||
|
||||
class CreateAssistantRequest(BaseModel):
    """Request body for creating an assistant (OpenAI-compatible, plus MemGPT extras)."""

    model: str = Field(..., description="The model to use for the assistant.")
    name: str = Field(..., description="The name of the assistant.")
    description: str = Field(None, description="The description of the assistant.")
    instructions: str = Field(..., description="The instructions for the assistant.")
    tools: List[str] = Field(None, description="The tools used by the assistant.")
    file_ids: List[str] = Field(None, description="List of file IDs associated with the assistant.")
    metadata: dict = Field(None, description="Metadata associated with the assistant.")

    # memgpt-only (not openai)
    # NOTE(review): description below looks copy-pasted from `model`; presumably
    # it should read "The embedding model ..." — confirm before changing.
    embedding_model: str = Field(None, description="The model to use for the assistant.")

    ## TODO: remove
    # user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
|
||||
|
||||
class CreateThreadRequest(BaseModel):
    """Request body for creating a thread (a thread maps onto a MemGPT agent)."""

    messages: Optional[List[str]] = Field(None, description="List of message IDs associated with the thread.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the thread.")

    # memgpt-only
    assistant_name: Optional[str] = Field(None, description="The name of the assistant (i.e. MemGPT preset)")
|
||||
|
||||
|
||||
class ModifyThreadRequest(BaseModel):
    """Request body for updating a thread's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the thread.")
|
||||
|
||||
|
||||
class ModifyMessageRequest(BaseModel):
    """Request body for updating a message's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the message.")
|
||||
|
||||
|
||||
class ModifyRunRequest(BaseModel):
    """Request body for updating a run's metadata."""

    metadata: dict = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class CreateMessageRequest(BaseModel):
    """Request body for appending a message to a thread."""

    role: str = Field(..., description="Role of the message sender (either 'user' or 'system')")
    content: str = Field(..., description="The message content to be processed by the agent.")
    file_ids: Optional[List[str]] = Field(None, description="List of file IDs associated with the message.")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the message.")
|
||||
|
||||
|
||||
class UserMessageRequest(BaseModel):
    """Request body for sending a user message directly to an agent.

    NOTE(review): this and the agent-message models below look like legacy
    agent-API schemas rather than OpenAI-assistants schemas — confirm usage.
    """

    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    message: str = Field(..., description="The message content to be processed by the agent.")
    stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
|
||||
|
||||
|
||||
class UserMessageResponse(BaseModel):
    """Response body carrying the agent's replies to a user message."""

    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
|
||||
|
||||
|
||||
class GetAgentMessagesRequest(BaseModel):
    """Request body for paging through an agent's message history."""

    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    # Paging is reverse-chronological: `start` is an offset, `count` a page size.
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")
|
||||
|
||||
|
||||
class ListMessagesResponse(BaseModel):
    """Response body for listing the messages of a thread."""

    messages: List[OpenAIMessage] = Field(..., description="List of message objects.")
|
||||
|
||||
|
||||
class CreateAssistantFileRequest(BaseModel):
    """Request body for attaching an existing file to an assistant."""

    file_id: str = Field(..., description="The unique identifier of the file.")
|
||||
|
||||
|
||||
class CreateRunRequest(BaseModel):
    """Request body for starting a run on an existing thread."""

    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    model: Optional[str] = Field(None, description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    additional_instructions: Optional[str] = Field(None, description="Additional instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class CreateThreadRunRequest(BaseModel):
    """Request body for creating a thread and immediately running it."""

    assistant_id: str = Field(..., description="The unique identifier of the assistant.")
    thread: OpenAIThread = Field(..., description="The thread to run.")
    model: str = Field(..., description="The model used by the run.")
    instructions: str = Field(..., description="The instructions for the run.")
    tools: Optional[List[ToolCall]] = Field(None, description="The tools used by the run (overrides assistant).")
    metadata: Optional[dict] = Field(None, description="Metadata associated with the run.")
|
||||
|
||||
|
||||
class DeleteAssistantResponse(BaseModel):
    """OpenAI-style deletion acknowledgement for an assistant."""

    id: str = Field(..., description="The unique identifier of the agent.")
    # Fixed OpenAI object-type tag for this response.
    object: str = "assistant.deleted"
    deleted: bool = Field(..., description="Whether the agent was deleted.")
|
||||
|
||||
|
||||
class DeleteAssistantFileResponse(BaseModel):
    """OpenAI-style deletion acknowledgement for an assistant file."""

    id: str = Field(..., description="The unique identifier of the file.")
    # Fixed OpenAI object-type tag for this response.
    object: str = "assistant.file.deleted"
    deleted: bool = Field(..., description="Whether the file was deleted.")
|
||||
|
||||
|
||||
class DeleteThreadResponse(BaseModel):
    """OpenAI-style deletion acknowledgement for a thread."""

    id: str = Field(..., description="The unique identifier of the agent.")
    # Fixed OpenAI object-type tag for this response.
    object: str = "thread.deleted"
    deleted: bool = Field(..., description="Whether the agent was deleted.")
|
||||
|
||||
|
||||
class SubmitToolOutputsToRunRequest(BaseModel):
    """Request body for submitting tool-call outputs back to a run."""

    tools_outputs: List[ToolCallOutput] = Field(..., description="The tool outputs to submit.")
|
||||
336
memgpt/server/rest_api/routers/openai/assistants/threads.py
Normal file
336
memgpt/server/rest_api/routers/openai/assistants/threads.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import uuid
from typing import List

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query

from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.agent import CreateAgent
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.message import Message
from memgpt.schemas.openai.openai import (
    MessageFile,
    OpenAIMessage,
    OpenAIRun,
    OpenAIRunStep,
    OpenAIThread,
    Text,
)
# NOTE: OpenAIThread was previously also re-imported from assistants.schemas,
# shadowing the direct import above; the redundant name has been dropped.
from memgpt.server.rest_api.routers.openai.assistants.schemas import (
    CreateMessageRequest,
    CreateRunRequest,
    CreateThreadRequest,
    CreateThreadRunRequest,
    DeleteThreadResponse,
    ListMessagesResponse,
    ModifyMessageRequest,
    ModifyRunRequest,
    ModifyThreadRequest,
    SubmitToolOutputsToRunRequest,
)
from memgpt.server.rest_api.utils import get_memgpt_server
from memgpt.server.server import SyncServer

# BUG FIX: get_utc_time is called at runtime (see create_run), so importing it
# only under `if TYPE_CHECKING:` made it a NameError in production. Import for real.
from memgpt.utils import get_utc_time

# TODO: implement mechanism for creating/authenticating users associated with a bearer token
router = APIRouter(prefix="/v1/threads", tags=["threads"])
|
||||
|
||||
|
||||
@router.post("/", response_model=OpenAIThread)
def create_thread(
    request: CreateThreadRequest = Body(...),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Create a thread by spinning up a fresh MemGPT agent and wrapping its id."""
    # TODO: use requests.description and requests.metadata fields
    # TODO: handle requests.file_ids and requests.tools
    # TODO: eventually allow request to override embedding/llm model
    actor = server.get_current_user()

    print("Create thread/agent", request)
    # A thread maps 1:1 onto a memgpt agent.
    new_agent = server.create_agent(
        request=CreateAgent(),
        user_id=actor.id,
    )
    # TODO: insert messages into recall memory
    thread = OpenAIThread(
        id=str(new_agent.id),
        created_at=int(new_agent.created_at.timestamp()),
        metadata={},  # TODO add metadata?
    )
    return thread
|
||||
|
||||
|
||||
@router.get("/{thread_id}", response_model=OpenAIThread)
def retrieve_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Fetch a thread by looking up the agent with the same id."""
    actor = server.get_current_user()
    backing_agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
    assert backing_agent is not None
    thread = OpenAIThread(
        id=str(backing_agent.id),
        created_at=int(backing_agent.created_at.timestamp()),
        metadata={},  # TODO add metadata?
    )
    return thread
|
||||
|
||||
|
||||
# BUG FIX: this was registered as GET "/{thread_id}", the exact same method+path
# as retrieve_thread above, so FastAPI always dispatched to retrieve_thread and
# this endpoint was unreachable. OpenAI's API modifies a thread via POST, which
# also matches modify_message/modify_run in this module.
@router.post("/{thread_id}", response_model=OpenAIThread)
def modify_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: ModifyThreadRequest = Body(...),
):
    """Modify a thread's metadata (not yet implemented)."""
    # TODO: add agent metadata so this can be modified
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.delete("/{thread_id}", response_model=DeleteThreadResponse)
def delete_thread(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
):
    """Delete a thread (not yet implemented)."""
    # TODO: delete agent
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
def create_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: CreateMessageRequest = Body(...),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Append a message to the agent behind *thread_id* and echo it OpenAI-style."""
    actor = server.get_current_user()
    agent_id = thread_id
    # Build the internal message record first.
    new_message = Message(
        user_id=actor.id,
        agent_id=agent_id,
        role=MessageRole(request.role),
        text=request.content,
        model=None,
        tool_calls=None,
        tool_call_id=None,
        name=None,
    )
    # Load the agent and append the message to its in-context history.
    target_agent = server._get_or_load_agent(agent_id=agent_id)
    target_agent._append_to_messages([new_message])

    return OpenAIMessage(
        id=str(new_message.id),
        created_at=int(new_message.created_at.timestamp()),
        content=[Text(text=(new_message.text if new_message.text else ""))],
        role=new_message.role,
        thread_id=str(new_message.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # TODO(sarah) fill in?
        run_id=None,
        file_ids=None,
        metadata=None,
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
def list_messages(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many messages to retrieve."),
    order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    # BUG FIX: the `before` description was copy-pasted from `after`.
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Page through the recall history of the agent behind *thread_id*, OpenAI-style."""
    actor = server.get_current_user()
    # BUG FIX: `after_uuid` was previously gated on `before` ("after if before"),
    # silently dropping the `after` cursor whenever `before` was absent.
    after_uuid = after if after else None
    before_uuid = before if before else None
    agent_id = thread_id
    reverse = order == "desc"
    json_messages = server.get_agent_recall_cursor(
        user_id=actor.id,
        agent_id=agent_id,
        limit=limit,
        after=after_uuid,
        before=before_uuid,
        order_by="created_at",
        reverse=reverse,
        return_message_object=True,
    )
    assert isinstance(json_messages, List)
    assert all([isinstance(message, Message) for message in json_messages])
    # BUG FIX: removed `json_messages[0]` assertion/print, which raised IndexError
    # on an empty history; also dropped the leftover debug prints.

    # Convert internal Message objects to OpenAI-style messages.
    openai_messages = [
        OpenAIMessage(
            id=str(message.id),
            created_at=int(message.created_at.timestamp()),
            content=[Text(text=(message.text if message.text else ""))],
            role=str(message.role),
            thread_id=str(message.agent_id),
            assistant_id=DEFAULT_PRESET,  # TODO: update this
            # TODO(sarah) fill in?
            run_id=None,
            file_ids=None,
            metadata=None,
            # file_ids=message.file_ids,
            # metadata=message.metadata,
        )
        for message in json_messages
    ]
    # TODO: cast back to message objects
    return ListMessagesResponse(messages=openai_messages)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def retrieve_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Fetch one message from the agent behind *thread_id* and return it OpenAI-style."""
    agent_id = thread_id
    stored = server.get_agent_message(agent_id=agent_id, message_id=message_id)
    assert stored is not None
    return OpenAIMessage(
        id=message_id,
        created_at=int(stored.created_at.timestamp()),
        content=[Text(text=(stored.text if stored.text else ""))],
        role=stored.role,
        thread_id=str(stored.agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        # TODO(sarah) fill in?
        run_id=None,
        file_ids=None,
        metadata=None,
        # file_ids=message.file_ids,
        # metadata=message.metadata,
    )
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
def retrieve_message_file(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    """Fetch a file attached to a message (not yet implemented)."""
    # TODO: implement?
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def modify_message(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    message_id: str = Path(..., description="The unique identifier of the message."),
    request: ModifyMessageRequest = Body(...),
):
    """Modify a message's metadata (not yet implemented)."""
    # TODO: add a metadata field to Message so this can be modified
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
def create_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    request: CreateRunRequest = Body(...),
    server: SyncServer = Depends(get_memgpt_server),
):
    """Synchronously step the agent behind *thread_id* and report a completed run."""
    # BUG FIX: at module level get_utc_time was imported only under TYPE_CHECKING,
    # so calling it here raised NameError at runtime. Import it for real so this
    # endpoint works even if the module header is left unchanged.
    from memgpt.utils import get_utc_time

    server.get_current_user()

    # TODO: add request.instructions as a message?
    agent_id = thread_id
    # TODO: override preset of agent with request.assistant_id
    agent = server._get_or_load_agent(agent_id=agent_id)
    agent.step(user_message=None)  # already has messages added
    run_id = str(uuid.uuid4())
    create_time = int(get_utc_time().timestamp())
    return OpenAIRun(
        id=run_id,
        created_at=create_time,
        thread_id=str(agent_id),
        assistant_id=DEFAULT_PRESET,  # TODO: update this
        status="completed",  # TODO: eventually allow offline execution
        expires_at=create_time,
        model=agent.agent_state.llm_config.model,
        instructions=request.instructions,
    )
|
||||
|
||||
|
||||
@router.post("/runs", tags=["runs"], response_model=OpenAIRun)
def create_thread_and_run(
    request: CreateThreadRunRequest = Body(...),
):
    """Create a thread and run it in one call (not yet implemented)."""
    # TODO: add a bunch of messages and execute
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
def list_runs(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many runs to retrieve."),
    order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    # BUG FIX: the `before` description was copy-pasted from `after`.
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List the runs of a thread (not yet implemented)."""
    # TODO: store run information in a DB so it can be returned here
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
def list_run_steps(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    limit: int = Query(1000, description="How many run steps to retrieve."),
    order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    # BUG FIX: the `before` description was copy-pasted from `after`.
    before: str = Query(None, description="A cursor for use in pagination. `before` is an object ID that defines your place in the list."),
):
    """List the steps of a run (not yet implemented)."""
    # TODO: store run information in a DB so it can be returned here
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def retrieve_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
):
    """Fetch a single run (not yet implemented)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
def retrieve_run_step(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    step_id: str = Path(..., description="The unique identifier of the run step."),
):
    """Fetch a single run step (not yet implemented)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def modify_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    request: ModifyRunRequest = Body(...),
):
    """Modify a run's metadata (not yet implemented)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
    request: SubmitToolOutputsToRunRequest = Body(...),
):
    """Submit tool-call outputs back to a run (not yet implemented)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
def cancel_run(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    run_id: str = Path(..., description="The unique identifier of the run."),
):
    """Cancel an in-flight run (not yet implemented)."""
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
@@ -0,0 +1,131 @@
|
||||
import json

from fastapi import APIRouter, Body, Depends, HTTPException

from memgpt.schemas.enums import MessageRole
from memgpt.schemas.memgpt_message import FunctionCall, MemGPTMessage
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import (
    ChatCompletionResponse,
    Choice,
    Message,
    UsageStatistics,
)

# TODO this belongs in a controller!
from memgpt.server.rest_api.routers.v1.agents import send_message_to_agent
from memgpt.server.rest_api.utils import get_memgpt_server
from memgpt.server.server import SyncServer
from memgpt.utils import get_utc_time

# NOTE: removed a dead `if TYPE_CHECKING: pass` block and its unused
# TYPE_CHECKING import — all names here are needed at runtime.

router = APIRouter(prefix="/v1/chat/completions", tags=["chat_completions"])
|
||||
|
||||
|
||||
@router.post("/", response_model=ChatCompletionResponse)
async def create_chat_completion(
    completion_request: ChatCompletionRequest = Body(...),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Send a message to a MemGPT agent via a /chat/completions completion_request
    The bearer token will be used to identify the user.
    The 'user' field in the completion_request should be set to the agent ID.
    """
    actor = server.get_current_user()
    agent_id = completion_request.user
    if agent_id is None:
        raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")

    messages = completion_request.messages
    # BUG FIX: an empty list previously slipped past the `is None` check and
    # crashed with an IndexError on messages[0]; `not messages` covers both.
    if not messages:
        raise HTTPException(status_code=400, detail="'messages' field must not be empty")
    if len(messages) != 1:
        raise HTTPException(status_code=400, detail="'messages' field must be a list of length 1")
    if messages[0].role != "user":
        raise HTTPException(status_code=400, detail="'messages[0].role' must be a 'user'")

    input_message = completion_request.messages[0]
    # TODO(charles) support multimodal parts
    assert isinstance(input_message.content, str)

    if completion_request.stream:
        print("Starting streaming OpenAI proxy response")
        return await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=actor.id,
            role=MessageRole(input_message.role),
            message=input_message.content,
            # Turn streaming ON
            stream_steps=True,
            stream_tokens=True,
            # Turn on ChatCompletion mode (eg remaps send_message to content)
            chat_completion_mode=True,
            return_message_object=False,
        )
    else:
        print("Starting non-streaming OpenAI proxy response")
        response_messages = await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=actor.id,
            role=MessageRole(input_message.role),
            message=input_message.content,
            # Turn streaming OFF
            stream_steps=False,
            stream_tokens=False,
            return_message_object=False,
        )

        # Concatenate all send_message outputs together
        id = ""
        visible_message_str = ""
        created_at = None
        for memgpt_msg in response_messages.messages:
            assert isinstance(memgpt_msg, MemGPTMessage)
            if isinstance(memgpt_msg, FunctionCall):
                if memgpt_msg.name and memgpt_msg.name == "send_message":
                    # BUG FIX: this was a bare `except:`, which also swallowed
                    # SystemExit/KeyboardInterrupt; catch only the parse failures.
                    try:
                        memgpt_function_call_args = json.loads(memgpt_msg.arguments)
                        visible_message_str += memgpt_function_call_args["message"]
                        id = memgpt_msg.id
                        created_at = memgpt_msg.date
                    except (json.JSONDecodeError, KeyError, TypeError):
                        print(f"Failed to parse MemGPT message: {str(memgpt_msg)}")
                else:
                    print(f"Skipping function_call: {str(memgpt_msg)}")
            else:
                print(f"Skipping message: {str(memgpt_msg)}")

        response = ChatCompletionResponse(
            id=id,
            created=created_at if created_at else get_utc_time(),
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=Message(
                        role="assistant",
                        content=visible_message_str,
                    ),
                )
            ],
            # TODO add real usage
            usage=UsageStatistics(
                completion_tokens=0,
                prompt_tokens=0,
                total_tokens=0,
            ),
        )
        return response
|
||||
15
memgpt/server/rest_api/routers/v1/__init__.py
Normal file
15
memgpt/server/rest_api/routers/v1/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from memgpt.server.rest_api.routers.v1.agents import router as agents_router
from memgpt.server.rest_api.routers.v1.blocks import router as blocks_router
from memgpt.server.rest_api.routers.v1.jobs import router as jobs_router
from memgpt.server.rest_api.routers.v1.llms import router as llm_router
from memgpt.server.rest_api.routers.v1.sources import router as sources_router
from memgpt.server.rest_api.routers.v1.tools import router as tools_router

# All v1 routers, in the order they are mounted on the app.
ROUTERS = [
    tools_router,
    sources_router,
    agents_router,
    llm_router,
    blocks_router,
    jobs_router,
]
|
||||
529
memgpt/server/rest_api/routers/v1/agents.py
Normal file
529
memgpt/server/rest_api/routers/v1/agents.py
Normal file
@@ -0,0 +1,529 @@
|
||||
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Union

from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
# BUG FIX: StreamingResponse was imported twice — once from fastapi.responses and
# again from starlette.responses (fastapi re-exports the starlette class); the
# redundant starlette import has been dropped.
from fastapi.responses import JSONResponse, StreamingResponse

from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.memory import (
    ArchivalMemorySummary,
    CreateArchivalMemory,
    Memory,
    RecallMemorySummary,
)
from memgpt.schemas.message import Message, UpdateMessage
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.utils import get_memgpt_server, sse_async_generator
from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate

# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally


router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[AgentState])
|
||||
def list_agents(
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.list_agents(user_id=actor.id)
|
||||
|
||||
|
||||
@router.post("/", response_model=AgentState)
|
||||
def create_agent(
|
||||
agent: CreateAgent = Body(...),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
agent.user_id = actor.id
|
||||
|
||||
return server.create_agent(agent, user_id=actor.id)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}", response_model=AgentState)
|
||||
def update_agent(
|
||||
agent_id: str,
|
||||
update_agent: UpdateAgentState = Body(...),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""Update an exsiting agent"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
update_agent.id = agent_id
|
||||
return server.update_agent(update_agent, user_id=actor.id)
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentState)
|
||||
def get_agent_state(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get the state of the agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
|
||||
# agent does not exist
|
||||
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
|
||||
|
||||
return server.get_agent_state(user_id=actor.id, agent_id=agent_id)
|
||||
|
||||
|
||||
@router.delete("/{agent_id}")
|
||||
def delete_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.delete_agent(user_id=actor.id, agent_id=agent_id)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/sources", response_model=List[Source])
|
||||
def get_agent_sources(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
"""
|
||||
server.get_current_user()
|
||||
|
||||
return server.list_attached_sources(agent_id)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/messages", response_model=List[Message])
|
||||
def get_agent_in_context_messages(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the messages in the context of a specific agent.
|
||||
"""
|
||||
|
||||
return server.get_in_context_messages(agent_id=agent_id)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory", response_model=Memory)
|
||||
def get_agent_memory(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the memory state of a specific agent.
|
||||
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
|
||||
return server.get_agent_memory(agent_id=agent_id)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/memory", response_model=Memory)
|
||||
def update_agent_memory(
|
||||
agent_id: str,
|
||||
request: Dict = Body(...),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Update the core memory of a specific agent.
|
||||
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request)
|
||||
return memory
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary)
|
||||
def get_agent_recall_memory_summary(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the summary of the recall memory of a specific agent.
|
||||
"""
|
||||
|
||||
return server.get_recall_memory_summary(agent_id=agent_id)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/archival", response_model=ArchivalMemorySummary)
|
||||
def get_agent_archival_memory_summary(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the summary of the archival memory of a specific agent.
|
||||
"""
|
||||
|
||||
return server.get_archival_memory_summary(agent_id=agent_id)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/archival", response_model=List[Passage])
|
||||
def get_agent_archival_memory(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
|
||||
limit: Optional[int] = Query(None, description="How many results to include in the response."),
|
||||
):
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
# TODO need to add support for non-postgres here
|
||||
# chroma will throw:
|
||||
# raise ValueError("Cannot run get_all_cursor with chroma")
|
||||
|
||||
return server.get_agent_archival_cursor(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/archival", response_model=List[Passage])
|
||||
def insert_agent_archival_memory(
|
||||
agent_id: str,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
|
||||
|
||||
|
||||
# TODO(ethan): query or path parameter for memory_id?
# @router.delete("/{agent_id}/archival")
@router.delete("/{agent_id}/archival/{memory_id}")
def delete_agent_archival_memory(
    agent_id: str,
    memory_id: str,
    # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
    server: "SyncServer" = Depends(get_memgpt_server),
):
    """Delete a single memory from the agent's archival store."""
    current_user = server.get_current_user()

    server.delete_archival_memory(user_id=current_user.id, agent_id=agent_id, memory_id=memory_id)
    return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
|
||||
|
||||
|
||||
@router.get("/{agent_id}/messages", response_model=List[Message])
|
||||
def get_agent_messages(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
|
||||
limit: int = Query(10, description="Maximum number of messages to retrieve."),
|
||||
msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."),
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.get_agent_recall_cursor(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
before=before,
|
||||
limit=limit,
|
||||
reverse=True,
|
||||
return_message_object=msg_object,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=Message)
|
||||
def update_message(
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
request: UpdateMessage = Body(...),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}"
|
||||
return server.update_agent_message(agent_id=agent_id, request=request)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/messages", response_model=MemGPTResponse)
|
||||
async def send_message(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
request: MemGPTRequest = Body(...),
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
# TODO(charles): support sending multiple messages
|
||||
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
|
||||
message = request.messages[0]
|
||||
|
||||
return await send_message_to_agent(
|
||||
server=server,
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
role=message.role,
|
||||
message=message.text,
|
||||
stream_steps=request.stream_steps,
|
||||
stream_tokens=request.stream_tokens,
|
||||
return_message_object=request.return_message_object,
|
||||
)
|
||||
|
||||
|
||||
# TODO: move this into server.py?
async def send_message_to_agent(
    server: SyncServer,
    agent_id: str,
    user_id: str,
    role: MessageRole,
    message: str,
    stream_steps: bool,
    stream_tokens: bool,
    return_message_object: bool,  # Should be True for Python Client, False for REST API
    chat_completion_mode: Optional[bool] = False,
    timestamp: Optional[datetime] = None,
    # related to whether or not we return `MemGPTMessage`s or `Message`s
) -> Union[StreamingResponse, MemGPTResponse]:
    """Send one message to an agent, either streaming the steps back as SSE or
    buffering them into a single MemGPTResponse.

    Split off into a separate function so that it can be imported in the
    /chat/completion proxy.

    Raises:
        HTTPException: 500 for an unsupported role or any unexpected failure;
            400 when stream_tokens is requested without stream_steps.
    """
    # TODO: @charles is this the correct way to handle?
    include_final_message = True

    # determine role -- only user/system messages are accepted here
    if role == MessageRole.user:
        message_func = server.user_message
    elif role == MessageRole.system:
        message_func = server.system_message
    else:
        raise HTTPException(status_code=500, detail=f"Bad role {role}")

    # token streaming implies step streaming; reject the inconsistent combination
    if not stream_steps and stream_tokens:
        raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")

    # For streaming response
    try:

        # TODO: move this logic into server.py

        # Get the generator object off of the agent's streaming interface
        # This will be attached to the POST SSE request used under-the-hood
        memgpt_agent = server._get_or_load_agent(agent_id=agent_id)
        streaming_interface = memgpt_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")

        # Enable token-streaming within the request if desired
        streaming_interface.streaming_mode = stream_tokens
        # "chatcompletion mode" does some remapping and ignores inner thoughts
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode

        # streaming_interface.allow_assistant_message = stream
        # streaming_interface.function_call_legacy_mode = stream

        # Offload the synchronous message_func to a separate thread
        # NOTE: stream_start() must run before the task so the generator is live
        # when the worker thread begins producing messages.
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
        )

        if stream_steps:
            if return_message_object:
                # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
                raise NotImplementedError

            # return a stream
            return StreamingResponse(
                sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
                media_type="text/event-stream",
            )

        else:
            # buffer the stream, then return the list
            generated_stream = []
            async for message in streaming_interface.get_generator():
                assert (
                    isinstance(message, MemGPTMessage)
                    or isinstance(message, LegacyMemGPTMessage)
                    or isinstance(message, MessageStreamStatus)
                ), type(message)
                generated_stream.append(message)
                if message == MessageStreamStatus.done:
                    break

            # Get rid of the stream status messages
            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
            # awaiting the task yields the usage statistics from message_func
            usage = await task

            # By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
            # If we want to convert these to Message, we can use the attached IDs
            # NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call)
            # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
            if return_message_object:
                message_ids = [m.id for m in filtered_stream]
                message_ids = deduplicate(message_ids)
                message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
                return MemGPTResponse(messages=message_objs, usage=usage)
            else:
                return MemGPTResponse(messages=filtered_stream, usage=usage)

    except HTTPException:
        raise
    except Exception as e:
        print(e)
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
|
||||
##### MISSING #######
|
||||
|
||||
# @router.post("/{agent_id}/command")
|
||||
# def run_command(
|
||||
# agent_id: "UUID",
|
||||
# command: "AgentCommandRequest",
|
||||
#
|
||||
# server: "SyncServer" = Depends(get_memgpt_server),
|
||||
# ):
|
||||
# """
|
||||
# Execute a command on a specified agent.
|
||||
|
||||
# This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately.
|
||||
|
||||
# Raises an HTTPException for any processing errors.
|
||||
# """
|
||||
# actor = server.get_current_user()
|
||||
#
|
||||
# response = server.run_command(user_id=actor.id,
|
||||
# agent_id=agent_id,
|
||||
# command=command.command)
|
||||
|
||||
# return AgentCommandResponse(response=response)
|
||||
|
||||
# @router.get("/{agent_id}/config")
|
||||
# def get_agent_config(
|
||||
# agent_id: "UUID",
|
||||
#
|
||||
# server: "SyncServer" = Depends(get_memgpt_server),
|
||||
# ):
|
||||
# """
|
||||
# Retrieve the configuration for a specific agent.
|
||||
|
||||
# This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
|
||||
# """
|
||||
# actor = server.get_current_user()
|
||||
#
|
||||
# if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
|
||||
## agent does not exist
|
||||
# raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
|
||||
|
||||
# agent_state = server.get_agent_config(user_id=actor.id, agent_id=agent_id)
|
||||
## get sources
|
||||
# attached_sources = server.list_attached_sources(agent_id=agent_id)
|
||||
|
||||
## configs
|
||||
# llm_config = LLMConfig(**vars(agent_state.llm_config))
|
||||
# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config))
|
||||
|
||||
# return GetAgentResponse(
|
||||
# agent_state=AgentState(
|
||||
# id=agent_state.id,
|
||||
# name=agent_state.name,
|
||||
# user_id=agent_state.user_id,
|
||||
# llm_config=llm_config,
|
||||
# embedding_config=embedding_config,
|
||||
# state=agent_state.state,
|
||||
# created_at=int(agent_state.created_at.timestamp()),
|
||||
# tools=agent_state.tools,
|
||||
# system=agent_state.system,
|
||||
# metadata=agent_state._metadata,
|
||||
# ),
|
||||
# last_run_at=None, # TODO
|
||||
# sources=attached_sources,
|
||||
# )
|
||||
|
||||
# @router.patch("/{agent_id}/rename", response_model=GetAgentResponse)
|
||||
# def update_agent_name(
|
||||
# agent_id: "UUID",
|
||||
# agent_rename: AgentRenameRequest,
|
||||
#
|
||||
# server: "SyncServer" = Depends(get_memgpt_server),
|
||||
# ):
|
||||
# """
|
||||
# Updates the name of a specific agent.
|
||||
|
||||
# This changes the name of the agent in the database but does NOT edit the agent's persona.
|
||||
# """
|
||||
# valid_name = agent_rename.agent_name
|
||||
# actor = server.get_current_user()
|
||||
#
|
||||
# agent_state = server.rename_agent(user_id=actor.id, agent_id=agent_id, new_agent_name=valid_name)
|
||||
## get sources
|
||||
# attached_sources = server.list_attached_sources(agent_id=agent_id)
|
||||
# llm_config = LLMConfig(**vars(agent_state.llm_config))
|
||||
# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config))
|
||||
|
||||
# return GetAgentResponse(
|
||||
# agent_state=AgentState(
|
||||
# id=agent_state.id,
|
||||
# name=agent_state.name,
|
||||
# user_id=agent_state.user_id,
|
||||
# llm_config=llm_config,
|
||||
# embedding_config=embedding_config,
|
||||
# state=agent_state.state,
|
||||
# created_at=int(agent_state.created_at.timestamp()),
|
||||
# tools=agent_state.tools,
|
||||
# system=agent_state.system,
|
||||
# ),
|
||||
# last_run_at=None, # TODO
|
||||
# sources=attached_sources,
|
||||
# )
|
||||
|
||||
|
||||
# @router.get("/{agent_id}/archival/all", response_model=GetAgentArchivalMemoryResponse)
|
||||
# def get_agent_archival_memory_all(
|
||||
# agent_id: "UUID",
|
||||
#
|
||||
# server: "SyncServer" = Depends(get_memgpt_server),
|
||||
# ):
|
||||
# """
|
||||
# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
|
||||
# """
|
||||
# actor = server.get_current_user()
|
||||
#
|
||||
# archival_memories = server.get_all_archival_memories(user_id=actor.id, agent_id=agent_id)
|
||||
# print("archival_memories:", archival_memories)
|
||||
# archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories]
|
||||
# return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)
|
||||
73
memgpt/server/rest_api/routers/v1/blocks.py
Normal file
73
memgpt/server/rest_api/routers/v1/blocks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
|
||||
from memgpt.schemas.block import Block, CreateBlock, UpdateBlock
|
||||
from memgpt.server.rest_api.utils import get_memgpt_server
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
router = APIRouter(prefix="/blocks", tags=["blocks"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Block])
|
||||
def list_blocks(
|
||||
# query parameters
|
||||
label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
|
||||
templates_only: bool = Query(True, description="Whether to include only templates"),
|
||||
name: Optional[str] = Query(None, description="Name of the block"),
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
actor = server.get_current_user()
|
||||
|
||||
blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name)
|
||||
if blocks is None:
|
||||
return []
|
||||
return blocks
|
||||
|
||||
|
||||
@router.post("/", response_model=Block)
|
||||
def create_block(
|
||||
create_block: CreateBlock = Body(...),
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
actor = server.get_current_user()
|
||||
|
||||
create_block.user_id = actor.id
|
||||
return server.create_block(user_id=actor.id, request=create_block)
|
||||
|
||||
|
||||
@router.patch("/{block_id}", response_model=Block)
|
||||
def update_block(
|
||||
block_id: str,
|
||||
updated_block: UpdateBlock = Body(...),
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
# actor = server.get_current_user()
|
||||
|
||||
updated_block.id = block_id
|
||||
return server.update_block(request=updated_block)
|
||||
|
||||
|
||||
# TODO: delete should not return anything
@router.delete("/{block_id}", response_model=Block)
def delete_block(
    block_id: str,
    server: SyncServer = Depends(get_memgpt_server),
):
    """Delete a memory block and return the deleted block."""
    return server.delete_block(block_id=block_id)
|
||||
|
||||
|
||||
@router.get("/{block_id}", response_model=Block)
|
||||
def get_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
|
||||
block = server.get_block(block_id=block_id)
|
||||
if block is None:
|
||||
raise HTTPException(status_code=404, detail="Block not found")
|
||||
return block
|
||||
46
memgpt/server/rest_api/routers/v1/jobs.py
Normal file
46
memgpt/server/rest_api/routers/v1/jobs.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from memgpt.schemas.job import Job
|
||||
from memgpt.server.rest_api.utils import get_memgpt_server
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Job])
|
||||
def list_jobs(
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
List all jobs.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
# TODO: add filtering by status
|
||||
return server.list_jobs(user_id=actor.id)
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[Job])
|
||||
def list_active_jobs(
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
List all active jobs.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.list_active_jobs(user_id=actor.id)
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=Job)
|
||||
def get_job(
|
||||
job_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get the status of a job.
|
||||
"""
|
||||
|
||||
return server.get_job(job_id=job_id)
|
||||
28
memgpt/server/rest_api/routers/v1/llms.py
Normal file
28
memgpt/server/rest_api/routers/v1/llms.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from memgpt.schemas.embedding_config import EmbeddingConfig
|
||||
from memgpt.schemas.llm_config import LLMConfig
|
||||
from memgpt.server.rest_api.utils import get_memgpt_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models", "llms"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[LLMConfig])
|
||||
def list_llm_backends(
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
|
||||
return server.list_models()
|
||||
|
||||
|
||||
@router.get("/embedding", response_model=List[EmbeddingConfig])
|
||||
def list_embedding_backends(
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
|
||||
return server.list_embedding_models()
|
||||
199
memgpt/server/rest_api/routers/v1/sources.py
Normal file
199
memgpt/server/rest_api/routers/v1/sources.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile
|
||||
|
||||
from memgpt.schemas.document import Document
|
||||
from memgpt.schemas.job import Job
|
||||
from memgpt.schemas.passage import Passage
|
||||
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from memgpt.server.rest_api.utils import get_memgpt_server
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
||||
|
||||
|
||||
router = APIRouter(prefix="/sources", tags=["sources"])
|
||||
|
||||
|
||||
@router.get("/{source_id}", response_model=Source)
|
||||
def get_source(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get all sources
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.get_source(source_id=source_id, user_id=actor.id)
|
||||
|
||||
|
||||
@router.get("/name/{source_name}", response_model=str)
|
||||
def get_source_id_by_name(
|
||||
source_name: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get a source by name
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
source = server.get_source_id(source_name=source_name, user_id=actor.id)
|
||||
return source
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Source])
|
||||
def list_sources(
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
List all data sources created by a user.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.list_all_sources(user_id=actor.id)
|
||||
|
||||
|
||||
@router.post("/", response_model=Source)
|
||||
def create_source(
|
||||
source: SourceCreate,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Create a new data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.create_source(request=source, user_id=actor.id)
|
||||
|
||||
|
||||
@router.patch("/{source_id}", response_model=Source)
|
||||
def update_source(
|
||||
source_id: str,
|
||||
source: SourceUpdate,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Update the name or documentation of an existing data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
assert source.id == source_id, "Source ID in path must match ID in request body"
|
||||
|
||||
return server.update_source(request=source, user_id=actor.id)
|
||||
|
||||
|
||||
@router.delete("/{source_id}")
|
||||
def delete_source(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Delete a data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
server.delete_source(source_id=source_id, user_id=actor.id)
|
||||
|
||||
|
||||
@router.post("/{source_id}/attach", response_model=Source)
|
||||
def attach_source_to_agent(
|
||||
source_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Attach a data source to an existing agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
source = server.ms.get_source(source_id=source_id, user_id=actor.id)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id)
|
||||
return source
|
||||
|
||||
|
||||
@router.post("/{source_id}/detach")
|
||||
def detach_source_from_agent(
|
||||
source_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
) -> None:
|
||||
"""
|
||||
Detach a data source from an existing agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
|
||||
|
||||
|
||||
@router.post("/{source_id}/upload", response_model=Job)
|
||||
def upload_file_to_source(
|
||||
file: UploadFile,
|
||||
source_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Upload a file to a data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
source = server.ms.get_source(source_id=source_id, user_id=actor.id)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
bytes = file.file.read()
|
||||
|
||||
# create job
|
||||
job = Job(
|
||||
user_id=actor.id,
|
||||
metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id},
|
||||
completed_at=None,
|
||||
)
|
||||
job_id = job.id
|
||||
server.ms.create_job(job)
|
||||
|
||||
# create background task
|
||||
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes)
|
||||
|
||||
# return job information
|
||||
job = server.ms.get_job(job_id=job_id)
|
||||
assert job is not None, "Job not found"
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/{source_id}/passages", response_model=List[Passage])
|
||||
def list_passages(
|
||||
source_id: str,
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
List all passages associated with a data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
|
||||
return passages
|
||||
|
||||
|
||||
@router.get("/{source_id}/documents", response_model=List[Document])
|
||||
def list_documents(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
List all documents associated with a data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id)
|
||||
return documents
|
||||
|
||||
|
||||
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
    """Background task: persist uploaded bytes to a temp file and load them into the source.

    NOTE(review): the `bytes` parameter shadows the builtin `bytes` type; it is
    kept as-is because the upload endpoint passes it by keyword (`bytes=...`).
    """
    # write the file to a temporary directory (deleted after the context manager exits)
    with tempfile.TemporaryDirectory() as tmpdirname:
        file_path = os.path.join(str(tmpdirname), str(file.filename))
        with open(file_path, "wb") as buffer:
            buffer.write(bytes)

        # hand the on-disk file to the server; job_id lets it report progress
        server.load_file_to_source(source_id, file_path, job_id)
|
||||
103
memgpt/server/rest_api/routers/v1/tools.py
Normal file
103
memgpt/server/rest_api/routers/v1/tools.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
|
||||
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from memgpt.server.rest_api.utils import get_memgpt_server
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
|
||||
@router.delete("/{tool_id}")
|
||||
def delete_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
# actor = server.get_current_user()
|
||||
server.delete_tool(tool_id=tool_id)
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=Tool)
|
||||
def get_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get a tool by ID
|
||||
"""
|
||||
# actor = server.get_current_user()
|
||||
|
||||
tool = server.get_tool(tool_id=tool_id)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
|
||||
return tool
|
||||
|
||||
|
||||
@router.get("/name/{tool_name}", response_model=str)
|
||||
def get_tool_id(
|
||||
tool_name: str,
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get a tool ID by name
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
tool_id = server.get_tool_id(tool_name, user_id=actor.id)
|
||||
if tool_id is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
|
||||
return tool_id
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Tool])
|
||||
def list_all_tools(
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor.id
|
||||
|
||||
# TODO: add back when user-specific
|
||||
return server.list_tools(user_id=actor.id)
|
||||
# return server.ms.list_tools(user_id=None)
|
||||
|
||||
|
||||
@router.post("/", response_model=Tool)
|
||||
def create_tool(
|
||||
tool: ToolCreate = Body(...),
|
||||
update: bool = False,
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
|
||||
return server.create_tool(
|
||||
request=tool,
|
||||
# update=update,
|
||||
update=True,
|
||||
user_id=actor.id,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{tool_id}", response_model=Tool)
|
||||
def update_tool(
|
||||
tool_id: str,
|
||||
request: ToolUpdate = Body(...),
|
||||
server: SyncServer = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Update an existing tool
|
||||
"""
|
||||
assert tool_id == request.id, "Tool ID in path must match tool ID in request body"
|
||||
server.get_current_user()
|
||||
return server.update_tool(request)
|
||||
109
memgpt/server/rest_api/routers/v1/users.py
Normal file
109
memgpt/server/rest_api/routers/v1/users.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
|
||||
from memgpt.schemas.api_key import APIKey, APIKeyCreate
|
||||
from memgpt.schemas.user import User, UserCreate
|
||||
from memgpt.server.rest_api.utils import get_memgpt_server
|
||||
|
||||
# from memgpt.server.schemas.users import (
|
||||
# CreateAPIKeyRequest,
|
||||
# CreateAPIKeyResponse,
|
||||
# CreateUserRequest,
|
||||
# CreateUserResponse,
|
||||
# DeleteAPIKeyResponse,
|
||||
# DeleteUserResponse,
|
||||
# GetAllUsersResponse,
|
||||
# GetAPIKeysResponse,
|
||||
# )
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from memgpt.schemas.user import User
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users", "admin"])
|
||||
|
||||
|
||||
@router.get("/", tags=["admin"], response_model=List[User])
|
||||
def get_all_users(
|
||||
cursor: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all users in the database
|
||||
"""
|
||||
try:
|
||||
next_cursor, users = server.ms.get_all_users(cursor=cursor, limit=limit)
|
||||
# processed_users = [{"user_id": user.id} for user in users]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return users
|
||||
|
||||
|
||||
@router.post("/", tags=["admin"], response_model=User)
|
||||
def create_user(
|
||||
request: UserCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Create a new user in the database
|
||||
"""
|
||||
|
||||
user = server.create_user(request)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/", tags=["admin"], response_model=User)
|
||||
def delete_user(
|
||||
user_id: str = Query(..., description="The user_id key to be deleted."),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
user = server.ms.get_user(user_id=user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail=f"User does not exist")
|
||||
server.ms.delete_user(user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/keys", response_model=APIKey)
|
||||
def create_new_api_key(
|
||||
create_key: APIKeyCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Create a new API key for a user
|
||||
"""
|
||||
api_key = server.create_api_key(create_key)
|
||||
return api_key
|
||||
|
||||
|
||||
@router.get("/keys", response_model=List[APIKey])
|
||||
def get_api_keys(
|
||||
user_id: str = Query(..., description="The unique identifier of the user."),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all API keys for a user
|
||||
"""
|
||||
if server.ms.get_user(user_id=user_id) is None:
|
||||
raise HTTPException(status_code=404, detail=f"User does not exist")
|
||||
api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id)
|
||||
return api_keys
|
||||
|
||||
|
||||
@router.delete("/keys", response_model=APIKey)
|
||||
def delete_api_key(
|
||||
api_key: str = Query(..., description="The API key to be deleted."),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
return server.delete_api_key(api_key)
|
||||
@@ -12,27 +12,24 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from memgpt.server.constants import REST_DEFAULT_PORT
|
||||
from memgpt.server.rest_api.admin.agents import setup_agents_admin_router
|
||||
from memgpt.server.rest_api.admin.tools import setup_tools_index_router
|
||||
from memgpt.server.rest_api.admin.users import setup_admin_router
|
||||
from memgpt.server.rest_api.agents.index import setup_agents_index_router
|
||||
from memgpt.server.rest_api.agents.memory import setup_agents_memory_router
|
||||
from memgpt.server.rest_api.agents.message import setup_agents_message_router
|
||||
from memgpt.server.rest_api.auth.index import setup_auth_router
|
||||
from memgpt.server.rest_api.block.index import setup_block_index_router
|
||||
from memgpt.server.rest_api.config.index import setup_config_index_router
|
||||
from memgpt.server.rest_api.auth.index import (
|
||||
setup_auth_router, # TODO: probably remove right?
|
||||
)
|
||||
from memgpt.server.rest_api.interface import StreamingServerInterface
|
||||
from memgpt.server.rest_api.jobs.index import setup_jobs_index_router
|
||||
from memgpt.server.rest_api.models.index import setup_models_index_router
|
||||
from memgpt.server.rest_api.openai_assistants.assistants import (
|
||||
setup_openai_assistant_router,
|
||||
from memgpt.server.rest_api.routers.openai.assistants.assistants import (
|
||||
router as openai_assistants_router,
|
||||
)
|
||||
from memgpt.server.rest_api.openai_chat_completions.chat_completions import (
|
||||
setup_openai_chat_completions_router,
|
||||
from memgpt.server.rest_api.routers.openai.assistants.threads import (
|
||||
router as openai_threads_router,
|
||||
)
|
||||
from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import (
|
||||
router as openai_chat_completions_router,
|
||||
)
|
||||
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
from memgpt.server.rest_api.routers.v1.users import (
|
||||
router as users_router, # TODO: decide on admin
|
||||
)
|
||||
from memgpt.server.rest_api.sources.index import setup_sources_index_router
|
||||
from memgpt.server.rest_api.static_files import mount_static_files
|
||||
from memgpt.server.rest_api.tools.index import setup_user_tools_index_router
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.settings import settings
|
||||
|
||||
@@ -44,8 +41,6 @@ Start the server with:
|
||||
poetry run uvicorn server:app --reload
|
||||
"""
|
||||
|
||||
# interface: QueuingInterface = QueuingInterface()
|
||||
# interface: StreamingServerInterface = StreamingServerInterface()
|
||||
interface: StreamingServerInterface = StreamingServerInterface
|
||||
server: SyncServer = SyncServer(default_interface_factory=lambda: interface())
|
||||
|
||||
@@ -66,10 +61,9 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
ADMIN_PREFIX = "/admin"
|
||||
ADMIN_API_PREFIX = "/api/admin"
|
||||
API_PREFIX = "/api"
|
||||
OPENAI_API_PREFIX = "/v1"
|
||||
ADMIN_PREFIX = "/v1/admin"
|
||||
API_PREFIX = "/v1"
|
||||
OPENAI_API_PREFIX = "/openai"
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -81,36 +75,51 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# v1_routes are the MemGPT API routes
|
||||
for route in v1_routes:
|
||||
app.include_router(route, prefix=API_PREFIX)
|
||||
# this gives undocumented routes for "latest" and bare api calls.
|
||||
# we should always tie this to the newest version of the api.
|
||||
app.include_router(route, prefix="", include_in_schema=False)
|
||||
app.include_router(route, prefix="/latest", include_in_schema=False)
|
||||
|
||||
# admin/users
|
||||
app.include_router(users_router, prefix=ADMIN_PREFIX)
|
||||
|
||||
# openai
|
||||
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
|
||||
app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX)
|
||||
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
|
||||
|
||||
# /api/auth endpoints
|
||||
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
# /admin/users endpoints
|
||||
app.include_router(setup_admin_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)])
|
||||
app.include_router(setup_tools_index_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)])
|
||||
# # Serve static files
|
||||
# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
|
||||
# app.mount("/assets", StaticFiles(directory=os.path.join(static_files_path, "assets")), name="static")
|
||||
|
||||
# /api/admin/agents endpoints
|
||||
app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)])
|
||||
|
||||
# /api/agents endpoints
|
||||
app.include_router(setup_agents_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_jobs_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
# # Serve favicon
|
||||
# @app.get("/favicon.ico")
|
||||
# async def favicon():
|
||||
# return FileResponse(os.path.join(static_files_path, "favicon.ico"))
|
||||
|
||||
# /api/config endpoints
|
||||
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
# /v1/assistants endpoints
|
||||
app.include_router(setup_openai_assistant_router(server, interface), prefix=OPENAI_API_PREFIX)
|
||||
# # Middleware to handle API routes first
|
||||
# @app.middleware("http")
|
||||
# async def handle_api_routes(request: Request, call_next):
|
||||
# if request.url.path.startswith(("/v1/", "/openai/")):
|
||||
# response = await call_next(request)
|
||||
# if response.status_code != 404:
|
||||
# return response
|
||||
# return await serve_spa(request.url.path)
|
||||
|
||||
|
||||
# # Catch-all route for SPA
|
||||
# async def serve_spa(full_path: str):
|
||||
# return FileResponse(os.path.join(static_files_path, "index.html"))
|
||||
|
||||
# /v1/chat/completions endpoints
|
||||
app.include_router(setup_openai_chat_completions_router(server, interface, password), prefix=OPENAI_API_PREFIX)
|
||||
|
||||
# / static files
|
||||
mount_static_files(app)
|
||||
|
||||
|
||||
|
||||
@@ -1,262 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Query,
|
||||
UploadFile,
|
||||
)
|
||||
|
||||
from memgpt.schemas.document import Document
|
||||
from memgpt.schemas.job import Job
|
||||
from memgpt.schemas.passage import Passage
|
||||
|
||||
# schemas
|
||||
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate, UploadFile
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
# Module-level router; routes are registered onto it inside setup_sources_index_router().
router = APIRouter()
|
||||
|
||||
"""
|
||||
Implement the following functions:
|
||||
* List all available sources
|
||||
* Create a new source
|
||||
* Delete a source
|
||||
* Upload a file to a server that is loaded into a specific source
|
||||
* Paginated get all passages from a source
|
||||
* Paginated get all documents from a source
|
||||
* Attach a source to an agent
|
||||
"""
|
||||
|
||||
|
||||
# class ListSourcesResponse(BaseModel):
|
||||
# sources: List[SourceModel] = Field(..., description="List of available sources.")
|
||||
#
|
||||
#
|
||||
# class CreateSourceRequest(BaseModel):
|
||||
# name: str = Field(..., description="The name of the source.")
|
||||
# description: Optional[str] = Field(None, description="The description of the source.")
|
||||
#
|
||||
#
|
||||
# class UploadFileToSourceRequest(BaseModel):
|
||||
# file: UploadFile = Field(..., description="The file to upload.")
|
||||
#
|
||||
#
|
||||
# class UploadFileToSourceResponse(BaseModel):
|
||||
# source: SourceModel = Field(..., description="The source the file was uploaded to.")
|
||||
# added_passages: int = Field(..., description="The number of passages added to the source.")
|
||||
# added_documents: int = Field(..., description="The number of documents added to the source.")
|
||||
#
|
||||
#
|
||||
# class GetSourcePassagesResponse(BaseModel):
|
||||
# passages: List[PassageModel] = Field(..., description="List of passages from the source.")
|
||||
#
|
||||
#
|
||||
# class GetSourceDocumentsResponse(BaseModel):
|
||||
# documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
|
||||
|
||||
|
||||
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
    """Background task: write an uploaded file into a temp dir and ingest it into a source."""
    # write the file to a temporary directory (deleted after the context manager exits)
    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): file.filename is client-controlled; a name containing ".."
        # could escape the temp dir — consider os.path.basename() here. TODO confirm
        # whether the upload layer already sanitizes filenames.
        file_path = os.path.join(tmpdirname, file.filename)
        with open(file_path, "wb") as buffer:
            buffer.write(bytes)

        # Ingestion happens while the temp dir still exists.
        server.load_file_to_source(source_id, file_path, job_id)
|
||||
|
||||
|
||||
def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the legacy /sources endpoints on the module-level router and return it.

    Every handler closes over ``server`` (the backing SyncServer), ``interface`` (an
    output queue cleared at the start of most handlers so stale messages don't leak
    between requests) and the auth dependency built from ``password``.
    """
    # Dependency that resolves the requesting user's ID from the bearer token.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/sources/{source_id}", tags=["sources"], response_model=Source)
    async def get_source(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a single data source by ID.
        """
        interface.clear()
        source = server.get_source(source_id=source_id, user_id=user_id)
        return source

    @router.get("/sources/name/{source_name}", tags=["sources"], response_model=str)
    async def get_source_id_by_name(
        source_name: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a source's ID by its name.
        """
        interface.clear()
        source = server.get_source_id(source_name=source_name, user_id=user_id)
        return source

    @router.get("/sources", tags=["sources"], response_model=List[Source])
    async def list_sources(
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        List all data sources created by a user.
        """
        # Clear the interface
        interface.clear()

        try:
            sources = server.list_all_sources(user_id=user_id)
            return sources
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources", tags=["sources"], response_model=Source)
    async def create_source(
        request: SourceCreate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Create a new data source.
        """
        interface.clear()
        try:
            return server.create_source(request=request, user_id=user_id)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources/{source_id}", tags=["sources"], response_model=Source)
    async def update_source(
        source_id: str,
        request: SourceUpdate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Update the name or documentation of an existing data source.
        """
        # NOTE(review): source_id from the path is not passed to update_source —
        # the body's ID appears to be authoritative; confirm this is intended.
        interface.clear()
        try:
            return server.update_source(request=request, user_id=user_id)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.delete("/sources/{source_id}", tags=["sources"])
    async def delete_source(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Delete a data source.
        """
        interface.clear()
        try:
            server.delete_source(source_id=source_id, user_id=user_id)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=Source)
    async def attach_source_to_agent(
        source_id: str,
        agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Attach a data source to an existing agent.
        """
        interface.clear()
        # NOTE(review): asserts are stripped under `python -O`; these sanity checks
        # would be safer as explicit HTTPException(400)s.
        assert isinstance(agent_id, str), f"Expected agent_id to be a UUID, got {agent_id}"
        assert isinstance(user_id, str), f"Expected user_id to be a UUID, got {user_id}"
        source = server.ms.get_source(source_id=source_id, user_id=user_id)
        source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=user_id)
        return source

    @router.post("/sources/{source_id}/detach", tags=["sources"])
    async def detach_source_from_agent(
        source_id: str,
        agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Detach a data source from an existing agent.
        """
        server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)

    @router.get("/sources/status/{job_id}", tags=["sources"], response_model=Job)
    async def get_job(
        job_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get the status of a job.
        """
        job = server.get_job(job_id=job_id)
        if job is None:
            raise HTTPException(status_code=404, detail=f"Job with id={job_id} not found.")
        return job

    @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=Job)
    async def upload_file_to_source(
        # file: UploadFile = UploadFile(..., description="The file to upload."),
        file: UploadFile,
        source_id: str,
        background_tasks: BackgroundTasks,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Upload a file to a data source.

        Creates a tracking Job, schedules the (blocking) ingestion as a background
        task, and returns the job so the client can poll /sources/status/{job_id}.
        """
        interface.clear()
        source = server.ms.get_source(source_id=source_id, user_id=user_id)
        # Read the full payload up front; the background task receives raw bytes
        # because the UploadFile's stream is closed once this handler returns.
        bytes = file.file.read()

        # create job
        # TODO: create server function
        job = Job(user_id=user_id, metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id})
        job_id = job.id
        server.ms.create_job(job)

        # create background task
        background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes)

        # return job information
        job = server.ms.get_job(job_id=job_id)
        return job

    # NOTE(review): the route path below ends with a trailing space ("passages ") —
    # almost certainly a typo, preserved here byte-for-byte.
    @router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=List[Passage])
    async def list_passages(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        List all passages associated with a data source.
        """
        # TODO: check if paginated?
        passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
        return passages

    @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=List[Document])
    async def list_documents(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        List all documents associated with a data source.
        """
        documents = server.list_data_source_documents(user_id=user_id, source_id=source_id)
        return documents

    return router
|
||||
@@ -21,7 +21,8 @@ def mount_static_files(app: FastAPI):
|
||||
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
|
||||
if os.path.exists(static_files_path):
|
||||
app.mount(
|
||||
"/",
|
||||
# "/",
|
||||
"/app",
|
||||
SPAStaticFiles(
|
||||
directory=static_files_path,
|
||||
html=True,
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
|
||||
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
# Module-level router; routes are registered onto it inside setup_user_tools_index_router().
router = APIRouter()
|
||||
|
||||
|
||||
def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the legacy /tools endpoints on the module-level router and return it.

    Handlers close over ``server``, ``interface`` (cleared before each call so queued
    output doesn't leak between requests) and the auth dependency built from ``password``.
    """
    # Dependency that resolves the requesting user's ID from the bearer token.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.delete("/tools/{tool_id}", tags=["tools"])
    async def delete_tool(
        tool_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Delete a tool by ID (the path parameter is the tool's identifier, not its name).
        """
        # Clear the interface
        interface.clear()
        # NOTE(review): user_id is resolved but not checked against tool ownership.
        server.delete_tool(tool_id)

    @router.get("/tools/{tool_id}", tags=["tools"], response_model=Tool)
    async def get_tool(
        tool_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a tool by ID.
        """
        # Clear the interface
        interface.clear()
        tool = server.get_tool(tool_id)
        if tool is None:
            # return 404 error
            raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
        return tool

    @router.get("/tools/name/{tool_name}", tags=["tools"], response_model=str)
    async def get_tool_id(
        tool_name: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a tool's ID by name.
        """
        # Clear the interface
        interface.clear()
        tool = server.get_tool_id(tool_name, user_id=user_id)
        if tool is None:
            # return 404 error
            raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
        return tool

    @router.get("/tools", tags=["tools"], response_model=List[Tool])
    async def list_all_tools(
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a list of all tools available to agents created by a user
        """
        # Clear the interface
        interface.clear()
        return server.list_tools(user_id)

    @router.post("/tools", tags=["tools"], response_model=Tool)
    async def create_tool(
        request: ToolCreate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Create a new tool
        """
        return server.create_tool(request, user_id=user_id)

    @router.post("/tools/{tool_id}", tags=["tools"], response_model=Tool)
    async def update_tool(
        tool_id: str,
        request: ToolUpdate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Update an existing tool
        """
        try:
            # TODO: check that the user has access to this tool?
            return server.update_tool(request)
        except Exception as e:
            # NOTE(review): prints to stdout and collapses all failures to a 500.
            print(e)
            raise HTTPException(status_code=500, detail=f"Failed to update tool: {e}")

    return router
|
||||
@@ -5,6 +5,9 @@ from typing import AsyncGenerator, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memgpt.server.rest_api.interface import StreamingServerInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
# from memgpt.orm.user import User
|
||||
# from memgpt.orm.utilities import get_db_session
|
||||
|
||||
@@ -51,3 +54,13 @@ async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
|
||||
if finish_message:
|
||||
# Signal that the stream is complete
|
||||
yield sse_formatter(SSE_FINISH_MSG)
|
||||
|
||||
|
||||
# TODO: why does this double up the interface?
def get_memgpt_server() -> SyncServer:
    # FastAPI dependency provider for the server handle.
    # NOTE(review): constructs a fresh SyncServer (with a new streaming-interface
    # factory) on every call — verify this is cheap / intended rather than a
    # per-process singleton.
    server = SyncServer(default_interface_factory=lambda: StreamingServerInterface())
    return server
|
||||
|
||||
|
||||
def get_current_interface() -> StreamingServerInterface:
    # NOTE(review): returns the StreamingServerInterface *class*, not an instance,
    # so the return annotation is arguably misleading — callers presumably
    # instantiate it themselves; confirm against usage sites.
    return StreamingServerInterface
|
||||
|
||||
@@ -889,10 +889,10 @@ class SyncServer(Server):
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
template: Optional[bool] = None,
|
||||
template: bool = True,
|
||||
name: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
):
|
||||
) -> Optional[List[Block]]:
|
||||
|
||||
return self.ms.get_blocks(user_id=user_id, label=label, template=template, name=name, id=id)
|
||||
|
||||
@@ -1550,14 +1550,14 @@ class SyncServer(Server):
|
||||
|
||||
return sources_with_metadata
|
||||
|
||||
def get_tool(self, tool_id: str) -> Tool:
|
||||
def get_tool(self, tool_id: str) -> Optional[Tool]:
|
||||
"""Get tool by ID."""
|
||||
return self.ms.get_tool(tool_id=tool_id)
|
||||
|
||||
def get_tool_id(self, name: str, user_id: str) -> str:
|
||||
def get_tool_id(self, name: str, user_id: str) -> Optional[str]:
|
||||
"""Get tool ID from name and user_id."""
|
||||
tool = self.ms.get_tool(tool_name=name, user_id=user_id)
|
||||
if not tool:
|
||||
if not tool or tool.id is None:
|
||||
return None
|
||||
return tool.id
|
||||
|
||||
@@ -1770,3 +1770,38 @@ class SyncServer(Server):
|
||||
# Get the current message
|
||||
memgpt_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
return memgpt_agent.retry_message()
|
||||
|
||||
    # TODO(ethan) wire back to real method in future ORM PR
    def get_current_user(self) -> User:
        """Returns the currently authed user.

        Since server is the core gateway this needs to pass through server as the
        first touchpoint.
        """
        # NOTE: same code as local client to get the default user
        config = MemGPTConfig.load()
        user_id = config.anon_clientid
        user = self.get_user(user_id)

        if not user:
            # First run (or a stale client ID): create a default user on the fly.
            user = self.create_user(UserCreate())

            # # update config
            # Persist the new user's ID so subsequent calls resolve the same user.
            config.anon_clientid = str(user.id)
            config.save()

        return user
|
||||
|
||||
def list_models(self) -> List[LLMConfig]:
|
||||
"""List available models"""
|
||||
|
||||
# TODO support multiple models
|
||||
llm_config = self.server_llm_config
|
||||
return [llm_config]
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
"""List available embedding models"""
|
||||
|
||||
# TODO support multiple models
|
||||
embedding_config = self.server_embedding_config
|
||||
return [embedding_config]
|
||||
|
||||
@@ -2,12 +2,14 @@ import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from memgpt import Admin, create_client
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.client.client import LocalClient, RESTClient
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.schemas.memory import ChatMemory
|
||||
|
||||
@@ -86,7 +88,7 @@ def agent(client):
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
def test_create_tool(client):
|
||||
def test_create_tool(client: Union[LocalClient, RESTClient]):
|
||||
"""Test creation of a simple tool"""
|
||||
|
||||
def print_tool(message: str):
|
||||
@@ -111,7 +113,9 @@ def test_create_tool(client):
|
||||
print(f"Updated tools {[t.name for t in tools]}")
|
||||
|
||||
# check tool id
|
||||
tool = client.get_tool(tool.name)
|
||||
tool = client.get_tool(tool.id)
|
||||
assert tool is not None, "Expected tool to be created"
|
||||
assert tool.id == tool.id, f"Expected {tool.id} to be {tool.id}"
|
||||
|
||||
|
||||
# TODO: add back once we fix admin client tool creation
|
||||
|
||||
Reference in New Issue
Block a user