feat: add GET REST API route for listing tools (#1100)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
""" Metadata store for user/agent/data_source information"""
|
||||
|
||||
import os
|
||||
import inspect as python_inspect
|
||||
import uuid
|
||||
import secrets
|
||||
from typing import Optional, List
|
||||
@@ -9,8 +10,9 @@ from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSON
|
||||
from memgpt.utils import get_local_time, enforce_types
|
||||
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
|
||||
from memgpt.models.pydantic_models import PersonaModel, HumanModel
|
||||
from memgpt.models.pydantic_models import PersonaModel, HumanModel, ToolModel
|
||||
|
||||
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
|
||||
from sqlalchemy import func
|
||||
@@ -517,6 +519,25 @@ class MetadataStore:
|
||||
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
|
||||
return [r.to_record() for r in results]
|
||||
|
||||
@enforce_types
|
||||
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
available_functions = load_all_function_sets()
|
||||
print(available_functions)
|
||||
results = [
|
||||
ToolModel(
|
||||
name=k,
|
||||
json_schema=v["json_schema"],
|
||||
source_type="python",
|
||||
source_code=python_inspect.getsource(v["python_function"]),
|
||||
)
|
||||
for k, v in available_functions.items()
|
||||
]
|
||||
print(results)
|
||||
return results
|
||||
# results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
|
||||
# return [r.to_record() for r in results]
|
||||
|
||||
@enforce_types
|
||||
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
|
||||
with self.session_maker() as session:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Optional, Dict, Literal
|
||||
from pydantic import BaseModel, Field, Json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@@ -36,6 +36,14 @@ class PresetModel(BaseModel):
|
||||
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
|
||||
|
||||
|
||||
class ToolModel(BaseModel):
|
||||
# TODO move into database
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
json_schema: dict = Field(..., description="The JSON schema of the function.")
|
||||
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
|
||||
source_code: Optional[str] = Field(..., description="The source code of the function.")
|
||||
|
||||
|
||||
class AgentStateModel(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
|
||||
name: str = Field(..., description="The name of the agent.")
|
||||
|
||||
@@ -20,6 +20,7 @@ from memgpt.server.rest_api.models.index import setup_models_index_router
|
||||
from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router
|
||||
from memgpt.server.rest_api.personas.index import setup_personas_index_router
|
||||
from memgpt.server.rest_api.static_files import mount_static_files
|
||||
from memgpt.server.rest_api.tools.index import setup_tools_index_router
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
"""
|
||||
@@ -92,6 +93,7 @@ app.include_router(setup_agents_message_router(server, interface, password), pre
|
||||
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
# /api/config endpoints
|
||||
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
0
memgpt/server/rest_api/tools/__init__.py
Normal file
0
memgpt/server/rest_api/tools/__init__.py
Normal file
35
memgpt/server/rest_api/tools/index.py
Normal file
35
memgpt/server/rest_api/tools/index.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.models.pydantic_models import ToolModel
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ListToolsResponse(BaseModel):
|
||||
tools: List[ToolModel] = Field(..., description="List of tools (functions).")
|
||||
|
||||
|
||||
def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
|
||||
async def list_tools(
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
tools = server.ms.list_tools(user_id=user_id)
|
||||
return ListToolsResponse(tools=tools)
|
||||
|
||||
return router
|
||||
Reference in New Issue
Block a user