feat: add GET REST API route for listing tools (#1100)

This commit is contained in:
Charles Packer
2024-03-05 22:11:24 -08:00
committed by GitHub
parent 7b8fcd3a42
commit bee8d1b72b
5 changed files with 68 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
""" Metadata store for user/agent/data_source information"""
import os
import inspect as python_inspect
import uuid
import secrets
from typing import Optional, List
@@ -9,8 +10,9 @@ from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSON
from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.functions.functions import load_all_function_sets
from memgpt.models.pydantic_models import PersonaModel, HumanModel
from memgpt.models.pydantic_models import PersonaModel, HumanModel, ToolModel
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
@@ -517,6 +519,25 @@ class MetadataStore:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return [r.to_record() for r in results]
@enforce_types
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
with self.session_maker() as session:
available_functions = load_all_function_sets()
print(available_functions)
results = [
ToolModel(
name=k,
json_schema=v["json_schema"],
source_type="python",
source_code=python_inspect.getsource(v["python_function"]),
)
for k, v in available_functions.items()
]
print(results)
return results
# results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
# return [r.to_record() for r in results]
@enforce_types
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
with self.session_maker() as session:

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Literal
from pydantic import BaseModel, Field, Json
import uuid
from datetime import datetime
@@ -36,6 +36,14 @@ class PresetModel(BaseModel):
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
class ToolModel(BaseModel):
# TODO move into database
name: str = Field(..., description="The name of the function.")
json_schema: dict = Field(..., description="The JSON schema of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")
class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")

View File

@@ -20,6 +20,7 @@ from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_tools_index_router
from memgpt.server.server import SyncServer
"""
@@ -92,6 +93,7 @@ app.include_router(setup_agents_message_router(server, interface, password), pre
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)

View File

View File

@@ -0,0 +1,35 @@
import uuid
from functools import partial
from typing import List
from fastapi import APIRouter, Depends, Body
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class ListToolsResponse(BaseModel):
tools: List[ToolModel] = Field(..., description="List of tools (functions).")
def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_tools(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
tools = server.ms.list_tools(user_id=user_id)
return ListToolsResponse(tools=tools)
return router