feat: Add endpoint to add base tools to an org (#1971)

This commit is contained in:
Matthew Zhou
2024-11-01 15:42:43 -07:00
committed by GitHub
parent 5204a3a4e0
commit c81c3e8297
8 changed files with 76 additions and 24 deletions

View File

@@ -225,6 +225,9 @@ class AbstractClient(object):
def get_tool_id(self, name: str) -> Optional[str]:
raise NotImplementedError
def add_base_tools(self) -> List[Tool]:
raise NotImplementedError
def load_data(self, connector: DataConnector, source_name: str):
raise NotImplementedError
@@ -1271,6 +1274,13 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to get tool: {response.text}")
return response.json()
def add_base_tools(self) -> List[Tool]:
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to add base tools: {response.text}")
return [Tool(**tool) for tool in response.json()]
def create_tool(
self,
func: Callable,

View File

@@ -36,7 +36,6 @@ DEFAULT_PRESET = "memgpt_chat"
# Tools
BASE_TOOLS = [
"send_message",
# "pause_heartbeats",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",

View File

@@ -107,23 +107,32 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
"""
# Start the query
query = select(cls)
# Collect query conditions for better error reporting
query_conditions = []
# If an identifier is provided, add it to the query conditions
if identifier is not None:
identifier = cls.get_uid_from_identifier(identifier)
query = query.where(cls._id == identifier)
query_conditions.append(f"id='{identifier}'")
if kwargs:
query = query.filter_by(**kwargs)
query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items()))
if actor:
query = cls.apply_access_predicate(query, actor, access)
query_conditions.append(f"access level in {access} for actor='{actor}'")
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
query_conditions.append("is_deleted=False")
if found := db_session.execute(query).scalar():
return found
raise NoResultFound(f"{cls.__name__} with id {identifier} not found")
# Construct a detailed error message based on query conditions
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
if actor:

View File

@@ -104,3 +104,15 @@ def update_tool(
"""
actor = server.get_user_or_default(user_id=user_id)
return server.tool_manager.update_tool_by_id(tool_id, actor.id, request)
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
def add_base_tools(
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Add base tools
"""
actor = server.get_user_or_default(user_id=user_id)
return server.tool_manager.add_base_tools(actor=actor)

View File

@@ -254,7 +254,7 @@ class SyncServer(Server):
self.default_org = self.organization_manager.create_default_organization()
self.default_user = self.user_manager.create_default_user()
self.add_default_blocks(self.default_user.id)
self.tool_manager.add_default_tools(module_name="base", actor=self.default_user)
self.tool_manager.add_base_tools(actor=self.default_user)
# If there is a default org/user
# This logic may have to change in the future

View File

@@ -18,6 +18,14 @@ from letta.utils import enforce_types
class ToolManager:
"""Manager class to handle business logic related to Tools."""
BASE_TOOL_NAMES = [
"send_message",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",
"archival_memory_search",
]
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.server import db_context
@@ -137,8 +145,9 @@ class ToolManager:
raise ValueError(f"Tool with id {tool_id} not found.")
@enforce_types
def add_default_tools(self, actor: PydanticUser, module_name="base"):
"""Add default tools in {module_name}.py"""
def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py"""
module_name = "base"
full_module_name = f"letta.functions.function_sets.{module_name}"
try:
module = importlib.import_module(full_module_name)
@@ -155,22 +164,28 @@ class ToolManager:
warnings.warn(err)
# create tool in db
tools = []
for name, schema in functions_to_schema.items():
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
if module_name == "base":
tags.append("letta-base")
if name in self.BASE_TOOL_NAMES:
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
if module_name == "base":
tags.append("letta-base")
# create to tool
self.create_or_update_tool(
ToolCreate(
name=name,
tags=tags,
source_type="python",
module=schema["module"],
source_code=source_code,
json_schema=schema["json_schema"],
),
actor=actor,
)
# create to tool
tools.append(
self.create_or_update_tool(
ToolCreate(
name=name,
tags=tags,
source_type="python",
module=schema["module"],
source_code=source_code,
json_schema=schema["json_schema"],
),
actor=actor,
)
)
return tools

View File

@@ -26,6 +26,7 @@ from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
from letta.settings import model_settings
from tests.helpers.client_helper import upload_file_using_client
@@ -299,7 +300,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
assert human.value == "Human text", "Creating human failed"
def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
tools = client.list_tools()
visited_ids = {t.id: False for t in tools}
@@ -321,6 +322,13 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: Ag
assert all(visited_ids.values())
def test_list_tools(client: Union[LocalClient, RESTClient]):
tools = client.add_base_tools()
tool_names = [t.name for t in tools]
expected = ToolManager.BASE_TOOL_NAMES
assert sorted(tool_names) == sorted(expected)
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
# clear sources
for source in client.list_sources():

View File

@@ -97,7 +97,6 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
[
"archival_memory_search",
"send_message",
"pause_heartbeats",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",