feat: Add endpoint to add base tools to an org (#1971)
This commit is contained in:
@@ -225,6 +225,9 @@ class AbstractClient(object):
|
||||
def get_tool_id(self, name: str) -> Optional[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def add_base_tools(self) -> List[Tool]:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_data(self, connector: DataConnector, source_name: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1271,6 +1274,13 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to get tool: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def add_base_tools(self) -> List[Tool]:
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to add base tools: {response.text}")
|
||||
|
||||
return [Tool(**tool) for tool in response.json()]
|
||||
|
||||
def create_tool(
|
||||
self,
|
||||
func: Callable,
|
||||
|
||||
@@ -36,7 +36,6 @@ DEFAULT_PRESET = "memgpt_chat"
|
||||
# Tools
|
||||
BASE_TOOLS = [
|
||||
"send_message",
|
||||
# "pause_heartbeats",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
|
||||
@@ -107,23 +107,32 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
"""
|
||||
# Start the query
|
||||
query = select(cls)
|
||||
# Collect query conditions for better error reporting
|
||||
query_conditions = []
|
||||
|
||||
# If an identifier is provided, add it to the query conditions
|
||||
if identifier is not None:
|
||||
identifier = cls.get_uid_from_identifier(identifier)
|
||||
query = query.where(cls._id == identifier)
|
||||
query_conditions.append(f"id='{identifier}'")
|
||||
|
||||
if kwargs:
|
||||
query = query.filter_by(**kwargs)
|
||||
query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items()))
|
||||
|
||||
if actor:
|
||||
query = cls.apply_access_predicate(query, actor, access)
|
||||
query_conditions.append(f"access level in {access} for actor='{actor}'")
|
||||
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
query_conditions.append("is_deleted=False")
|
||||
if found := db_session.execute(query).scalar():
|
||||
return found
|
||||
raise NoResultFound(f"{cls.__name__} with id {identifier} not found")
|
||||
|
||||
# Construct a detailed error message based on query conditions
|
||||
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
||||
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
|
||||
|
||||
def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
|
||||
if actor:
|
||||
|
||||
@@ -104,3 +104,15 @@ def update_tool(
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.tool_manager.update_tool_by_id(tool_id, actor.id, request)
|
||||
|
||||
|
||||
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
|
||||
def add_base_tools(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Add base tools
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.tool_manager.add_base_tools(actor=actor)
|
||||
|
||||
@@ -254,7 +254,7 @@ class SyncServer(Server):
|
||||
self.default_org = self.organization_manager.create_default_organization()
|
||||
self.default_user = self.user_manager.create_default_user()
|
||||
self.add_default_blocks(self.default_user.id)
|
||||
self.tool_manager.add_default_tools(module_name="base", actor=self.default_user)
|
||||
self.tool_manager.add_base_tools(actor=self.default_user)
|
||||
|
||||
# If there is a default org/user
|
||||
# This logic may have to change in the future
|
||||
|
||||
@@ -18,6 +18,14 @@ from letta.utils import enforce_types
|
||||
class ToolManager:
|
||||
"""Manager class to handle business logic related to Tools."""
|
||||
|
||||
BASE_TOOL_NAMES = [
|
||||
"send_message",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
"archival_memory_search",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
# Fetching the db_context similarly as in OrganizationManager
|
||||
from letta.server.server import db_context
|
||||
@@ -137,8 +145,9 @@ class ToolManager:
|
||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def add_default_tools(self, actor: PydanticUser, module_name="base"):
|
||||
"""Add default tools in {module_name}.py"""
|
||||
def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""Add default tools in base.py"""
|
||||
module_name = "base"
|
||||
full_module_name = f"letta.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
@@ -155,22 +164,28 @@ class ToolManager:
|
||||
warnings.warn(err)
|
||||
|
||||
# create tool in db
|
||||
tools = []
|
||||
for name, schema in functions_to_schema.items():
|
||||
# print([str(inspect.getsource(line)) for line in schema["imports"]])
|
||||
source_code = inspect.getsource(schema["python_function"])
|
||||
tags = [module_name]
|
||||
if module_name == "base":
|
||||
tags.append("letta-base")
|
||||
if name in self.BASE_TOOL_NAMES:
|
||||
# print([str(inspect.getsource(line)) for line in schema["imports"]])
|
||||
source_code = inspect.getsource(schema["python_function"])
|
||||
tags = [module_name]
|
||||
if module_name == "base":
|
||||
tags.append("letta-base")
|
||||
|
||||
# create to tool
|
||||
self.create_or_update_tool(
|
||||
ToolCreate(
|
||||
name=name,
|
||||
tags=tags,
|
||||
source_type="python",
|
||||
module=schema["module"],
|
||||
source_code=source_code,
|
||||
json_schema=schema["json_schema"],
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
# create to tool
|
||||
tools.append(
|
||||
self.create_or_update_tool(
|
||||
ToolCreate(
|
||||
name=name,
|
||||
tags=tags,
|
||||
source_type="python",
|
||||
module=schema["module"],
|
||||
source_code=source_code,
|
||||
json_schema=schema["json_schema"],
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
@@ -26,6 +26,7 @@ from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import model_settings
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
@@ -299,7 +300,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
|
||||
assert human.value == "Human text", "Creating human failed"
|
||||
|
||||
|
||||
def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
|
||||
tools = client.list_tools()
|
||||
visited_ids = {t.id: False for t in tools}
|
||||
|
||||
@@ -321,6 +322,13 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: Ag
|
||||
assert all(visited_ids.values())
|
||||
|
||||
|
||||
def test_list_tools(client: Union[LocalClient, RESTClient]):
|
||||
tools = client.add_base_tools()
|
||||
tool_names = [t.name for t in tools]
|
||||
expected = ToolManager.BASE_TOOL_NAMES
|
||||
assert sorted(tool_names) == sorted(expected)
|
||||
|
||||
|
||||
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# clear sources
|
||||
for source in client.list_sources():
|
||||
|
||||
@@ -97,7 +97,6 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
|
||||
[
|
||||
"archival_memory_search",
|
||||
"send_message",
|
||||
"pause_heartbeats",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
|
||||
Reference in New Issue
Block a user