feat: Add orm for Tools and clean up Tool logic (#1935)
This commit is contained in:
16
.github/workflows/code_style_checks.yml
vendored
16
.github/workflows/code_style_checks.yml
vendored
@@ -1,21 +1,13 @@
|
||||
name: Code Style Checks
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
pull_request_target:
|
||||
types:
|
||||
- opened
|
||||
- edited
|
||||
- synchronize
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
pull-requests: read
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
validation-checks:
|
||||
style-checks:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
4
.github/workflows/test_cli.yml
vendored
4
.github/workflows/test_cli.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Run CLI tests
|
||||
name: Test CLI
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
@@ -10,7 +10,7 @@ on:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
test-cli:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
|
||||
|
||||
39
.github/workflows/tests.yml
vendored
39
.github/workflows/tests.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Run All pytest Tests
|
||||
name: Unit Tests
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
@@ -13,7 +13,7 @@ on:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
unit-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
|
||||
@@ -37,6 +37,28 @@ jobs:
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests"
|
||||
|
||||
- name: Run LocalClient tests
|
||||
env:
|
||||
LETTA_PG_PORT: 8888
|
||||
LETTA_PG_USER: letta
|
||||
LETTA_PG_PASSWORD: letta
|
||||
LETTA_PG_DB: letta
|
||||
LETTA_PG_HOST: localhost
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_local_client.py
|
||||
|
||||
- name: Run RESTClient tests
|
||||
env:
|
||||
LETTA_PG_PORT: 8888
|
||||
LETTA_PG_USER: letta
|
||||
LETTA_PG_PASSWORD: letta
|
||||
LETTA_PG_DB: letta
|
||||
LETTA_PG_HOST: localhost
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_client.py
|
||||
|
||||
- name: Run server tests
|
||||
env:
|
||||
LETTA_PG_PORT: 8888
|
||||
@@ -70,6 +92,17 @@ jobs:
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_tools.py
|
||||
|
||||
- name: Run o1 agent tests
|
||||
env:
|
||||
LETTA_PG_PORT: 8888
|
||||
LETTA_PG_USER: letta
|
||||
LETTA_PG_PASSWORD: letta
|
||||
LETTA_PG_DB: letta
|
||||
LETTA_PG_HOST: localhost
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_o1_agent.py
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
LETTA_PG_PORT: 8888
|
||||
@@ -80,4 +113,4 @@ jobs:
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
run: |
|
||||
poetry run pytest -s -vv -k "not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers" tests
|
||||
poetry run pytest -s -vv -k "not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
|
||||
|
||||
@@ -5,7 +5,6 @@ from letta import create_client
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
"""
|
||||
Setup here.
|
||||
@@ -49,10 +48,7 @@ def main():
|
||||
from composio_langchain import Action
|
||||
|
||||
# Add the composio tool
|
||||
tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
|
||||
# create tool
|
||||
client.add_tool(tool)
|
||||
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
|
||||
persona = f"""
|
||||
My name is Letta.
|
||||
|
||||
@@ -2,8 +2,9 @@ import json
|
||||
import uuid
|
||||
|
||||
from letta import create_client
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
"""
|
||||
This example show how you can add CrewAI tools .
|
||||
@@ -21,14 +22,14 @@ def main():
|
||||
|
||||
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
|
||||
|
||||
example_website_scrape_tool = Tool.from_crewai(crewai_tool)
|
||||
tool_name = example_website_scrape_tool.name
|
||||
|
||||
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
|
||||
client = create_client()
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
# create tool
|
||||
client.add_tool(example_website_scrape_tool)
|
||||
example_website_scrape_tool = client.load_crewai_tool(crewai_tool)
|
||||
tool_name = example_website_scrape_tool.name
|
||||
|
||||
# Confirm that the tool is in
|
||||
tools = client.list_tools()
|
||||
|
||||
@@ -5,7 +5,6 @@ from letta import create_client
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
"""
|
||||
This example show how you can add LangChain tools .
|
||||
@@ -26,25 +25,22 @@ def main():
|
||||
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=500)
|
||||
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
|
||||
# Translate to memGPT Tool
|
||||
# Note the additional_imports_module_attr_map
|
||||
# We need to pass in a map of all the additional imports necessary to run this tool
|
||||
# Because an object of type WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool,
|
||||
# We need to also import WikipediaAPIWrapper
|
||||
# The map is a mapping of the module name to the attribute name
|
||||
# langchain_community.utilities.WikipediaAPIWrapper
|
||||
wikipedia_query_tool = Tool.from_langchain(
|
||||
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
)
|
||||
tool_name = wikipedia_query_tool.name
|
||||
|
||||
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
|
||||
client = create_client()
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
# create tool
|
||||
client.add_tool(wikipedia_query_tool)
|
||||
# Note the additional_imports_module_attr_map
|
||||
# We need to pass in a map of all the additional imports necessary to run this tool
|
||||
# Because an object of type WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool,
|
||||
# We need to also import WikipediaAPIWrapper
|
||||
# The map is a mapping of the module name to the attribute name
|
||||
# langchain_community.utilities.WikipediaAPIWrapper
|
||||
wikipedia_query_tool = client.load_langchain_tool(
|
||||
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
)
|
||||
tool_name = wikipedia_query_tool.name
|
||||
|
||||
# Confirm that the tool is in
|
||||
tools = client.list_tools()
|
||||
|
||||
@@ -240,7 +240,6 @@ class Agent(BaseAgent):
|
||||
assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}"
|
||||
|
||||
# link tools
|
||||
self.tools = tools
|
||||
self.link_tools(tools)
|
||||
|
||||
# gpt-4, gpt-3.5-turbo, ...
|
||||
@@ -1521,10 +1520,6 @@ def save_agent(agent: Agent, ms: MetadataStore):
|
||||
else:
|
||||
ms.create_agent(agent_state)
|
||||
|
||||
for tool in agent.tools:
|
||||
if ms.get_tool(tool_name=tool.name, user_id=tool.user_id) is None:
|
||||
ms.create_tool(tool)
|
||||
|
||||
agent.agent_state = ms.get_agent(agent_id=agent_id)
|
||||
assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ def run(
|
||||
# read user id from config
|
||||
ms = MetadataStore(config)
|
||||
client = create_client()
|
||||
server = client.server
|
||||
|
||||
# determine agent to use, if not provided
|
||||
if not yes and not agent:
|
||||
@@ -217,7 +218,9 @@ def run(
|
||||
)
|
||||
|
||||
# create agent
|
||||
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
|
||||
tools = [
|
||||
server.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=client.user_id) for tool_name in agent_state.tools
|
||||
]
|
||||
letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)
|
||||
|
||||
else: # create new agent
|
||||
@@ -297,7 +300,7 @@ def run(
|
||||
)
|
||||
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
|
||||
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
|
||||
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
|
||||
tools = [server.tool_manager.get_tool_by_name_and_user_id(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
|
||||
|
||||
letta_agent = Agent(
|
||||
interface=interface(),
|
||||
|
||||
@@ -182,6 +182,15 @@ class AbstractClient(object):
|
||||
def delete_human(self, id: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_composio_tool(self, action: "ActionType") -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
def create_tool(
|
||||
self,
|
||||
func,
|
||||
@@ -1298,12 +1307,6 @@ class RESTClient(AbstractClient):
|
||||
source_code = parse_source_code(func)
|
||||
source_type = "python"
|
||||
|
||||
# TODO: Check if tool already exists
|
||||
# if name:
|
||||
# tool_id = self.get_tool_id(tool_name=name)
|
||||
# if tool_id:
|
||||
# raise ValueError(f"Tool with name {name} (id={tool_id}) already exists")
|
||||
|
||||
# call server function
|
||||
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
|
||||
@@ -1337,7 +1340,7 @@ class RESTClient(AbstractClient):
|
||||
|
||||
source_type = "python"
|
||||
|
||||
request = ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name)
|
||||
request = ToolUpdate(source_type=source_type, source_code=source_code, tags=tags, name=name)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update tool: {response.text}")
|
||||
@@ -1394,6 +1397,7 @@ class RESTClient(AbstractClient):
|
||||
params["cursor"] = str(cursor)
|
||||
if limit:
|
||||
params["limit"] = limit
|
||||
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list tools: {response.text}")
|
||||
@@ -1503,6 +1507,7 @@ class LocalClient(AbstractClient):
|
||||
self,
|
||||
auto_save: bool = False,
|
||||
user_id: Optional[str] = None,
|
||||
org_id: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
default_llm_config: Optional[LLMConfig] = None,
|
||||
default_embedding_config: Optional[EmbeddingConfig] = None,
|
||||
@@ -1529,15 +1534,19 @@ class LocalClient(AbstractClient):
|
||||
self.interface = QueuingInterface(debug=debug)
|
||||
self.server = SyncServer(default_interface_factory=lambda: self.interface)
|
||||
|
||||
# save org_id that `LocalClient` is associated with
|
||||
if org_id:
|
||||
self.org_id = org_id
|
||||
else:
|
||||
self.org_id = self.server.organization_manager.DEFAULT_ORG_ID
|
||||
# save user_id that `LocalClient` is associated with
|
||||
if user_id:
|
||||
self.user_id = user_id
|
||||
else:
|
||||
# get default user
|
||||
self.user_id = self.server.get_default_user().id
|
||||
self.user_id = self.server.user_manager.DEFAULT_USER_ID
|
||||
|
||||
# agents
|
||||
|
||||
def list_agents(self) -> List[AgentState]:
|
||||
self.interface.clear()
|
||||
|
||||
@@ -2186,50 +2195,27 @@ class LocalClient(AbstractClient):
|
||||
self.server.delete_block(id)
|
||||
|
||||
# tools
|
||||
def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
|
||||
tool_create = ToolCreate.from_langchain(
|
||||
langchain_tool=langchain_tool,
|
||||
user_id=self.user_id,
|
||||
organization_id=self.org_id,
|
||||
additional_imports_module_attr_map=additional_imports_module_attr_map,
|
||||
)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create)
|
||||
|
||||
# TODO: merge this into create_tool
|
||||
def add_tool(self, tool: Tool, update: Optional[bool] = True) -> Tool:
|
||||
"""
|
||||
Adds a tool directly.
|
||||
def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
|
||||
tool_create = ToolCreate.from_crewai(
|
||||
crewai_tool=crewai_tool,
|
||||
additional_imports_module_attr_map=additional_imports_module_attr_map,
|
||||
user_id=self.user_id,
|
||||
organization_id=self.org_id,
|
||||
)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create)
|
||||
|
||||
Args:
|
||||
tool (Tool): The tool to add.
|
||||
update (bool, optional): Update the tool if it already exists. Defaults to True.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if self.tool_with_name_and_user_id_exists(tool):
|
||||
if update:
|
||||
return self.server.update_tool(
|
||||
ToolUpdate(
|
||||
id=tool.id,
|
||||
description=tool.description,
|
||||
source_type=tool.source_type,
|
||||
source_code=tool.source_code,
|
||||
tags=tool.tags,
|
||||
json_schema=tool.json_schema,
|
||||
name=tool.name,
|
||||
),
|
||||
self.user_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Tool with id={tool.id} and name={tool.name}already exists")
|
||||
else:
|
||||
# call server function
|
||||
return self.server.create_tool(
|
||||
ToolCreate(
|
||||
id=tool.id,
|
||||
description=tool.description,
|
||||
source_type=tool.source_type,
|
||||
source_code=tool.source_code,
|
||||
name=tool.name,
|
||||
json_schema=tool.json_schema,
|
||||
tags=tool.tags,
|
||||
),
|
||||
user_id=self.user_id,
|
||||
update=update,
|
||||
)
|
||||
def load_composio_tool(self, action: "ActionType") -> Tool:
|
||||
tool_create = ToolCreate.from_composio(action=action, user_id=self.user_id, organization_id=self.org_id)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create)
|
||||
|
||||
# TODO: Use the above function `add_tool` here as there is duplicate logic
|
||||
def create_tool(
|
||||
@@ -2262,11 +2248,16 @@ class LocalClient(AbstractClient):
|
||||
tags = []
|
||||
|
||||
# call server function
|
||||
return self.server.create_tool(
|
||||
# ToolCreate(source_type=source_type, source_code=source_code, name=tool_name, json_schema=json_schema, tags=tags),
|
||||
ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags, terminal=terminal),
|
||||
user_id=self.user_id,
|
||||
update=update,
|
||||
return self.server.tool_manager.create_or_update_tool(
|
||||
ToolCreate(
|
||||
user_id=self.user_id,
|
||||
organization_id=self.org_id,
|
||||
source_type=source_type,
|
||||
source_code=source_code,
|
||||
name=name,
|
||||
tags=tags,
|
||||
terminal=terminal,
|
||||
),
|
||||
)
|
||||
|
||||
def update_tool(
|
||||
@@ -2288,16 +2279,17 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
tool (Tool): Updated tool
|
||||
"""
|
||||
if func:
|
||||
source_code = parse_source_code(func)
|
||||
else:
|
||||
source_code = None
|
||||
update_data = {
|
||||
"source_type": "python", # Always include source_type
|
||||
"source_code": parse_source_code(func) if func else None,
|
||||
"tags": tags,
|
||||
"name": name,
|
||||
}
|
||||
|
||||
source_type = "python"
|
||||
# Filter out any None values from the dictionary
|
||||
update_data = {key: value for key, value in update_data.items() if value is not None}
|
||||
|
||||
return self.server.update_tool(
|
||||
ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id
|
||||
)
|
||||
return self.server.tool_manager.update_tool_by_id(id, ToolUpdate(**update_data))
|
||||
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
|
||||
"""
|
||||
@@ -2306,11 +2298,11 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
tools (List[Tool]): List of tools
|
||||
"""
|
||||
return self.server.list_tools(cursor=cursor, limit=limit, user_id=self.user_id)
|
||||
return self.server.tool_manager.list_tools_for_org(cursor=cursor, limit=limit, organization_id=self.org_id)
|
||||
|
||||
def get_tool(self, id: str) -> Optional[Tool]:
|
||||
"""
|
||||
Get a tool give its ID.
|
||||
Get a tool given its ID.
|
||||
|
||||
Args:
|
||||
id (str): ID of the tool
|
||||
@@ -2318,7 +2310,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
tool (Tool): Tool
|
||||
"""
|
||||
return self.server.get_tool(id)
|
||||
return self.server.tool_manager.get_tool_by_id(id)
|
||||
|
||||
def delete_tool(self, id: str):
|
||||
"""
|
||||
@@ -2327,11 +2319,11 @@ class LocalClient(AbstractClient):
|
||||
Args:
|
||||
id (str): ID of the tool
|
||||
"""
|
||||
return self.server.delete_tool(id)
|
||||
return self.server.tool_manager.delete_tool_by_id(id)
|
||||
|
||||
def get_tool_id(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Get the ID of a tool
|
||||
Get the ID of a tool from its name. The client will use the org_id it is configured with.
|
||||
|
||||
Args:
|
||||
name (str): Name of the tool
|
||||
@@ -2339,19 +2331,8 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
id (str): ID of the tool (`None` if not found)
|
||||
"""
|
||||
return self.server.get_tool_id(name, self.user_id)
|
||||
|
||||
def tool_with_name_and_user_id_exists(self, tool: Tool) -> bool:
|
||||
"""
|
||||
Check if the tool with name and user_id exists
|
||||
|
||||
Args:
|
||||
tool (Tool): the tool
|
||||
|
||||
Returns:
|
||||
(bool): True if the id exists, False otherwise.
|
||||
"""
|
||||
return self.server.tool_with_name_and_user_id_exists(tool, self.user_id)
|
||||
tool = self.server.tool_manager.get_tool_by_name_and_org_id(tool_name=name, organization_id=self.org_id)
|
||||
return tool.id
|
||||
|
||||
def load_data(self, connector: DataConnector, source_name: str):
|
||||
"""
|
||||
|
||||
@@ -13,13 +13,13 @@ from letta.constants import (
|
||||
DEFAULT_HUMAN,
|
||||
DEFAULT_PERSONA,
|
||||
DEFAULT_PRESET,
|
||||
DEFAULT_USER_ID,
|
||||
LETTA_DIR,
|
||||
)
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.services.user_manager import UserManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -45,7 +45,7 @@ def set_field(config, section, field, value):
|
||||
@dataclass
|
||||
class LettaConfig:
|
||||
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(LETTA_DIR, "config")
|
||||
anon_clientid: str = DEFAULT_USER_ID
|
||||
anon_clientid: str = UserManager.DEFAULT_USER_ID
|
||||
|
||||
# preset
|
||||
preset: str = DEFAULT_PRESET # TODO: rename to system prompt
|
||||
|
||||
@@ -3,15 +3,6 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
|
||||
|
||||
LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
|
||||
|
||||
# Defaults
|
||||
DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000"
|
||||
# This UUID follows the UUID4 rules:
|
||||
# The 13th character (4) indicates it's version 4.
|
||||
# The first character of the third segment (8) ensures the variant is correctly set.
|
||||
DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000"
|
||||
DEFAULT_USER_NAME = "default_user"
|
||||
DEFAULT_ORG_NAME = "default_org"
|
||||
|
||||
|
||||
# String in the error message for when the context window is too large
|
||||
# Example full message:
|
||||
|
||||
@@ -3,9 +3,33 @@ import inspect
|
||||
import os
|
||||
from textwrap import dedent # remove indentation
|
||||
from types import ModuleType
|
||||
from typing import Optional
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
from letta.schemas.tool import ToolCreate
|
||||
|
||||
|
||||
def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
|
||||
# auto-generate openai schema
|
||||
try:
|
||||
# Define a custom environment with necessary imports
|
||||
env = {
|
||||
"Optional": Optional, # Add any other required imports here
|
||||
}
|
||||
|
||||
env.update(globals())
|
||||
exec(tool_create.source_code, env)
|
||||
|
||||
# get available functions
|
||||
functions = [f for f in env if callable(env[f])]
|
||||
|
||||
# TODO: not sure if this always works
|
||||
func = env[functions[-1]]
|
||||
json_schema = generate_schema(func, terminal=tool_create.terminal, name=tool_create.name)
|
||||
return json_schema
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to execute source code: {e}")
|
||||
|
||||
|
||||
def parse_source_code(func) -> str:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import humps
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -8,7 +9,7 @@ def generate_composio_tool_wrapper(action: "ActionType") -> tuple[str, str]:
|
||||
tool_instantiation_str = f"composio_toolset.get_tools(actions=[Action.{str(action)}])[0]"
|
||||
|
||||
# Generate func name
|
||||
func_name = f"run_{action.name.lower()}"
|
||||
func_name = action.name.lower()
|
||||
|
||||
wrapper_function_str = f"""
|
||||
def {func_name}(**kwargs):
|
||||
@@ -40,7 +41,7 @@ def generate_langchain_tool_wrapper(
|
||||
|
||||
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
|
||||
run_call = f"return tool._run(**kwargs)"
|
||||
func_name = f"run_{tool_name.lower()}"
|
||||
func_name = humps.decamelize(tool_name)
|
||||
|
||||
# Combine all parts into the wrapper function
|
||||
wrapper_function_str = f"""
|
||||
@@ -70,7 +71,7 @@ def generate_crewai_tool_wrapper(tool: "CrewAIBaseTool", additional_imports_modu
|
||||
|
||||
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
|
||||
run_call = f"return tool._run(**kwargs)"
|
||||
func_name = f"run_{tool_name.lower()}"
|
||||
func_name = humps.decamelize(tool_name)
|
||||
|
||||
# Combine all parts into the wrapper function
|
||||
wrapper_function_str = f"""
|
||||
|
||||
@@ -74,7 +74,7 @@ def pydantic_model_to_open_ai(model):
|
||||
}
|
||||
|
||||
|
||||
def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None):
|
||||
def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None) -> dict:
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function)
|
||||
|
||||
@@ -139,7 +139,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No
|
||||
|
||||
|
||||
def generate_schema_from_args_schema(
|
||||
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None
|
||||
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
properties = {}
|
||||
required = []
|
||||
@@ -163,4 +163,12 @@ def generate_schema_from_args_schema(
|
||||
"parameters": {"type": "object", "properties": properties, "required": required},
|
||||
}
|
||||
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
if append_heartbeat:
|
||||
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
function_call_json["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
return function_call_json
|
||||
|
||||
@@ -14,8 +14,6 @@ from sqlalchemy import (
|
||||
Integer,
|
||||
String,
|
||||
TypeDecorator,
|
||||
asc,
|
||||
or_,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
@@ -32,7 +30,6 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types, get_utc_time, printd
|
||||
@@ -359,37 +356,6 @@ class BlockModel(Base):
|
||||
)
|
||||
|
||||
|
||||
class ToolModel(Base):
|
||||
__tablename__ = "tools"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
name = Column(String, nullable=False)
|
||||
user_id = Column(String)
|
||||
description = Column(String)
|
||||
source_type = Column(String)
|
||||
source_code = Column(String)
|
||||
json_schema = Column(JSON)
|
||||
module = Column(String)
|
||||
tags = Column(JSON)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tool(id='{self.id}', name='{self.name}')>"
|
||||
|
||||
def to_record(self) -> Tool:
|
||||
return Tool(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
user_id=self.user_id,
|
||||
description=self.description,
|
||||
source_type=self.source_type,
|
||||
source_code=self.source_code,
|
||||
json_schema=self.json_schema,
|
||||
module=self.module,
|
||||
tags=self.tags,
|
||||
)
|
||||
|
||||
|
||||
class JobModel(Base):
|
||||
__tablename__ = "jobs"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
@@ -516,14 +482,6 @@ class MetadataStore:
|
||||
session.add(BlockModel(**vars(block)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def create_tool(self, tool: Tool):
|
||||
with self.session_maker() as session:
|
||||
if self.get_tool(tool_id=tool.id, tool_name=tool.name, user_id=tool.user_id) is not None:
|
||||
raise ValueError(f"Tool with name {tool.name} already exists")
|
||||
session.add(ToolModel(**vars(tool)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def update_agent(self, agent: AgentState):
|
||||
with self.session_maker() as session:
|
||||
@@ -556,18 +514,6 @@ class MetadataStore:
|
||||
session.add(BlockModel(**vars(block)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def update_tool(self, tool_id: str, tool: Tool):
|
||||
with self.session_maker() as session:
|
||||
session.query(ToolModel).filter(ToolModel.id == tool_id).update(vars(tool))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_tool(self, tool_id: str):
|
||||
with self.session_maker() as session:
|
||||
session.query(ToolModel).filter(ToolModel.id == tool_id).delete()
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]):
|
||||
with self.session_maker() as session:
|
||||
@@ -612,23 +558,6 @@ class MetadataStore:
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
# Query for public tools or user-specific tools
|
||||
query = session.query(ToolModel).filter(or_(ToolModel.user_id == None, ToolModel.user_id == user_id))
|
||||
|
||||
# Apply cursor if provided (assuming cursor is an ID)
|
||||
if cursor:
|
||||
query = query.filter(ToolModel.id > cursor)
|
||||
|
||||
# Order by ID and apply limit
|
||||
results = query.order_by(asc(ToolModel.id)).limit(limit).all()
|
||||
|
||||
# Convert to records
|
||||
res = [r.to_record() for r in results]
|
||||
return res
|
||||
|
||||
@enforce_types
|
||||
def list_agents(self, user_id: str) -> List[AgentState]:
|
||||
with self.session_maker() as session:
|
||||
@@ -672,32 +601,6 @@ class MetadataStore:
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_tool(
|
||||
self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None
|
||||
) -> Optional[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
if tool_id:
|
||||
results = session.query(ToolModel).filter(ToolModel.id == tool_id).all()
|
||||
else:
|
||||
assert tool_name is not None
|
||||
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
|
||||
if user_id:
|
||||
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
# assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_tool_with_name_and_user_id(self, tool_name: Optional[str] = None, user_id: Optional[str] = None) -> Optional[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_block(self, block_id: str) -> Optional[Block]:
|
||||
with self.session_maker() as session:
|
||||
|
||||
@@ -10,7 +10,7 @@ from letta.schemas.tool import Tool
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
|
||||
def send_thinking_message(self: Agent, message: str) -> Optional[str]:
|
||||
def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
|
||||
"""
|
||||
Sends a thinking message so that the model can reason out loud before responding.
|
||||
|
||||
@@ -24,7 +24,7 @@ def send_thinking_message(self: Agent, message: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def send_final_message(self: Agent, message: str) -> Optional[str]:
|
||||
def send_final_message(self: "Agent", message: str) -> Optional[str]:
|
||||
"""
|
||||
Sends a final message to the human user after thinking for a while.
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
"""__all__ acts as manual import management to avoid collisions and circular imports."""
|
||||
|
||||
# from letta.orm.agent import Agent
|
||||
# from letta.orm.users_agents import UsersAgents
|
||||
# from letta.orm.blocks_agents import BlocksAgents
|
||||
# from letta.orm.token import Token
|
||||
# from letta.orm.source import Source
|
||||
# from letta.orm.document import Document
|
||||
# from letta.orm.passage import Passage
|
||||
# from letta.orm.memory_templates import MemoryTemplate, HumanMemoryTemplate, PersonaMemoryTemplate
|
||||
# from letta.orm.sources_agents import SourcesAgents
|
||||
# from letta.orm.tools_agents import ToolsAgents
|
||||
# from letta.orm.job import Job
|
||||
# from letta.orm.block import Block
|
||||
# from letta.orm.message import Message
|
||||
|
||||
@@ -55,7 +55,6 @@ class OrganizationMixin(Base):
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
# Changed _organization_id to store string (still a valid UUID4 string)
|
||||
_organization_id: Mapped[str] = mapped_column(String, ForeignKey("organization._id"))
|
||||
|
||||
@property
|
||||
@@ -65,3 +64,19 @@ class OrganizationMixin(Base):
|
||||
@organization_id.setter
|
||||
def organization_id(self, value: str) -> None:
|
||||
_relation_setter(self, "organization", value)
|
||||
|
||||
|
||||
class UserMixin(Base):
|
||||
"""Mixin for models that belong to a user."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
_user_id: Mapped[str] = mapped_column(String, ForeignKey("user._id"))
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
return _relation_getter(self, "user")
|
||||
|
||||
@user_id.setter
|
||||
def user_id(self, value: str) -> None:
|
||||
_relation_setter(self, "user", value)
|
||||
|
||||
@@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
|
||||
|
||||
@@ -19,6 +20,7 @@ class Organization(SqlalchemyBase):
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
|
||||
|
||||
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
# TODO: Map these relationships later when we actually make these models
|
||||
# below is just a suggestion
|
||||
|
||||
@@ -184,21 +184,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
logger.warning("to_record is deprecated, use to_pydantic instead.")
|
||||
return self.to_pydantic()
|
||||
|
||||
# TODO: Look into this later and maybe add back?
|
||||
# def _infer_organization(self, db_session: "Session") -> None:
|
||||
# """🪄 MAGIC ALERT! 🪄
|
||||
# Because so much of the original API is centered around user scopes,
|
||||
# this allows us to continue with that scope and then infer the org from the creating user.
|
||||
#
|
||||
# IF a created_by_id is set, we will use that to infer the organization and magic set it at create time!
|
||||
# If not do nothing to the object. Mutates in place.
|
||||
# """
|
||||
# if self.created_by_id and hasattr(self, "_organization_id"):
|
||||
# try:
|
||||
# from letta.orm.user import User # to avoid circular import
|
||||
#
|
||||
# created_by = User.read(db_session, self.created_by_id)
|
||||
# except NoResultFound:
|
||||
# logger.warning(f"User {self.created_by_id} not found, unable to infer organization.")
|
||||
# return
|
||||
# self._organization_id = created_by._organization_id
|
||||
def _infer_organization(self, db_session: "Session") -> None:
|
||||
"""🪄 MAGIC ALERT! 🪄
|
||||
Because so much of the original API is centered around user scopes,
|
||||
this allows us to continue with that scope and then infer the org from the creating user.
|
||||
|
||||
IF a created_by_id is set, we will use that to infer the organization and magic set it at create time!
|
||||
If not do nothing to the object. Mutates in place.
|
||||
"""
|
||||
if self.created_by_id and hasattr(self, "_organization_id"):
|
||||
try:
|
||||
from letta.orm.user import User # to avoid circular import
|
||||
|
||||
created_by = User.read(db_session, self.created_by_id)
|
||||
except NoResultFound:
|
||||
logger.warning(f"User {self.created_by_id} not found, unable to infer organization.")
|
||||
return
|
||||
self._organization_id = created_by._organization_id
|
||||
|
||||
54
letta/orm/tool.py
Normal file
54
letta/orm/tool.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
# TODO everything in functions should live in this model
|
||||
from letta.orm.enums import ToolSourceType
|
||||
from letta.orm.mixins import OrganizationMixin, UserMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.user import User
|
||||
|
||||
|
||||
class Tool(SqlalchemyBase, OrganizationMixin, UserMixin):
|
||||
"""Represents an available tool that the LLM can invoke.
|
||||
|
||||
NOTE: polymorphic inheritance makes more sense here as a TODO. We want a superset of tools
|
||||
that are always available, and a subset scoped to the organization. Alternatively, we could use the apply_access_predicate to build
|
||||
more granular permissions.
|
||||
"""
|
||||
|
||||
__tablename__ = "tool"
|
||||
__pydantic_model__ = PydanticTool
|
||||
|
||||
# Add unique constraint on (name, _organization_id)
|
||||
# An organization should not have multiple tools with the same name
|
||||
__table_args__ = (
|
||||
UniqueConstraint("name", "_organization_id", name="uix_name_organization"),
|
||||
UniqueConstraint("name", "_user_id", name="uix_name_user"),
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
|
||||
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
|
||||
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
|
||||
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
|
||||
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
|
||||
json_schema: Mapped[dict] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
|
||||
module: Mapped[Optional[str]] = mapped_column(
|
||||
String, nullable=True, doc="the module path from which this tool was derived in the codebase."
|
||||
)
|
||||
|
||||
# TODO: add terminal here eventually
|
||||
# This was an intentional decision by Sarah
|
||||
|
||||
# relationships
|
||||
# TODO: Possibly add in user in the future
|
||||
# This will require some more thought and justification to add this in.
|
||||
user: Mapped["User"] = relationship("User", back_populates="tools", lazy="selectin")
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
||||
@@ -1,10 +1,15 @@
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.tool import Tool
|
||||
|
||||
|
||||
class User(SqlalchemyBase, OrganizationMixin):
|
||||
"""User ORM class"""
|
||||
@@ -16,6 +21,7 @@ class User(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
# TODO: Add this back later potentially
|
||||
# agents: Mapped[List["Agent"]] = relationship(
|
||||
|
||||
@@ -10,19 +10,13 @@ from letta.functions.helpers import (
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.user_manager import UserManager
|
||||
|
||||
|
||||
class BaseTool(LettaBase):
|
||||
__id_prefix__ = "tool"
|
||||
|
||||
# optional fields
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
module: Optional[str] = Field(None, description="The module of the function.")
|
||||
|
||||
# optional: user_id (user-specific tools)
|
||||
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the function.")
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""
|
||||
@@ -37,8 +31,12 @@ class Tool(BaseTool):
|
||||
|
||||
"""
|
||||
|
||||
id: str = BaseTool.generate_id_field()
|
||||
|
||||
id: str = Field(..., description="The id of the tool.")
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
module: Optional[str] = Field(None, description="The module of the function.")
|
||||
user_id: str = Field(..., description="The unique identifier of the user associated with the tool.")
|
||||
organization_id: str = Field(..., description="The unique identifier of the organization associated with the tool.")
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
tags: List[str] = Field(..., description="Metadata tags.")
|
||||
|
||||
@@ -58,14 +56,31 @@ class Tool(BaseTool):
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ToolCreate(LettaBase):
|
||||
user_id: str = Field(UserManager.DEFAULT_USER_ID, description="The user that this tool belongs to. Defaults to the default user ID.")
|
||||
organization_id: str = Field(
|
||||
OrganizationManager.DEFAULT_ORG_ID,
|
||||
description="The organization that this tool belongs to. Defaults to the default organization ID.",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).")
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
module: Optional[str] = Field(None, description="The source code of the function.")
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
source_type: str = Field(..., description="The source type of the function.")
|
||||
json_schema: Optional[Dict] = Field(
|
||||
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
|
||||
)
|
||||
terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).")
|
||||
|
||||
@classmethod
|
||||
def get_composio_tool(
|
||||
cls,
|
||||
action: "ActionType",
|
||||
) -> "Tool":
|
||||
def from_composio(
|
||||
cls, action: "ActionType", user_id: str = UserManager.DEFAULT_USER_ID, organization_id: str = OrganizationManager.DEFAULT_ORG_ID
|
||||
) -> "ToolCreate":
|
||||
"""
|
||||
Class method to create an instance of Letta-compatible Composio Tool.
|
||||
Check https://docs.composio.dev/introduction/intro/overview to look at options for get_composio_tool
|
||||
Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio
|
||||
|
||||
This function will error if we find more than one tool, or 0 tools.
|
||||
|
||||
@@ -90,14 +105,9 @@ class Tool(BaseTool):
|
||||
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action)
|
||||
json_schema = generate_schema_from_args_schema(composio_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
json_schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
json_schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
name=wrapper_func_name,
|
||||
description=description,
|
||||
source_type=source_type,
|
||||
@@ -107,7 +117,13 @@ class Tool(BaseTool):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool":
|
||||
def from_langchain(
|
||||
cls,
|
||||
langchain_tool: "LangChainBaseTool",
|
||||
additional_imports_module_attr_map: dict[str, str] = None,
|
||||
user_id: str = UserManager.DEFAULT_USER_ID,
|
||||
organization_id: str = OrganizationManager.DEFAULT_ORG_ID,
|
||||
) -> "ToolCreate":
|
||||
"""
|
||||
Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools).
|
||||
|
||||
@@ -125,14 +141,9 @@ class Tool(BaseTool):
|
||||
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool, additional_imports_module_attr_map)
|
||||
json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
json_schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
json_schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
name=wrapper_func_name,
|
||||
description=description,
|
||||
source_type=source_type,
|
||||
@@ -142,7 +153,13 @@ class Tool(BaseTool):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_crewai(cls, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool":
|
||||
def from_crewai(
|
||||
cls,
|
||||
crewai_tool: "CrewAIBaseTool",
|
||||
additional_imports_module_attr_map: dict[str, str] = None,
|
||||
user_id: str = UserManager.DEFAULT_USER_ID,
|
||||
organization_id: str = OrganizationManager.DEFAULT_ORG_ID,
|
||||
) -> "ToolCreate":
|
||||
"""
|
||||
Class method to create an instance of Tool from a crewAI BaseTool object.
|
||||
|
||||
@@ -158,14 +175,9 @@ class Tool(BaseTool):
|
||||
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool, additional_imports_module_attr_map)
|
||||
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
json_schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
json_schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
name=wrapper_func_name,
|
||||
description=description,
|
||||
source_type=source_type,
|
||||
@@ -175,54 +187,43 @@ class Tool(BaseTool):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_default_langchain_tools(cls) -> List["Tool"]:
|
||||
def load_default_langchain_tools(cls) -> List["ToolCreate"]:
|
||||
# For now, we only support wikipedia tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
|
||||
wikipedia_tool = Tool.from_langchain(
|
||||
wikipedia_tool = ToolCreate.from_langchain(
|
||||
WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()), {"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
)
|
||||
|
||||
return [wikipedia_tool]
|
||||
|
||||
@classmethod
|
||||
def load_default_crewai_tools(cls) -> List["Tool"]:
|
||||
def load_default_crewai_tools(cls) -> List["ToolCreate"]:
|
||||
# For now, we only support scrape website tool
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
|
||||
web_scrape_tool = Tool.from_crewai(ScrapeWebsiteTool())
|
||||
web_scrape_tool = ToolCreate.from_crewai(ScrapeWebsiteTool())
|
||||
|
||||
return [web_scrape_tool]
|
||||
|
||||
@classmethod
|
||||
def load_default_composio_tools(cls) -> List["Tool"]:
|
||||
def load_default_composio_tools(cls) -> List["ToolCreate"]:
|
||||
from composio_langchain import Action
|
||||
|
||||
calculator = Tool.get_composio_tool(action=Action.MATHEMATICAL_CALCULATOR)
|
||||
serp_news = Tool.get_composio_tool(action=Action.SERPAPI_NEWS_SEARCH)
|
||||
serp_google_search = Tool.get_composio_tool(action=Action.SERPAPI_SEARCH)
|
||||
serp_google_maps = Tool.get_composio_tool(action=Action.SERPAPI_GOOGLE_MAPS_SEARCH)
|
||||
calculator = ToolCreate.from_composio(action=Action.MATHEMATICAL_CALCULATOR)
|
||||
serp_news = ToolCreate.from_composio(action=Action.SERPAPI_NEWS_SEARCH)
|
||||
serp_google_search = ToolCreate.from_composio(action=Action.SERPAPI_SEARCH)
|
||||
serp_google_maps = ToolCreate.from_composio(action=Action.SERPAPI_GOOGLE_MAPS_SEARCH)
|
||||
|
||||
return [calculator, serp_news, serp_google_search, serp_google_maps]
|
||||
|
||||
|
||||
class ToolCreate(BaseTool):
|
||||
id: Optional[str] = Field(None, description="The unique identifier of the tool. If this is not provided, it will be autogenerated.")
|
||||
name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).")
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
json_schema: Optional[Dict] = Field(
|
||||
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
|
||||
)
|
||||
terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).")
|
||||
|
||||
|
||||
class ToolUpdate(ToolCreate):
|
||||
id: str = Field(..., description="The unique identifier of the tool.")
|
||||
class ToolUpdate(LettaBase):
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
name: Optional[str] = Field(None, description="The name of the function.")
|
||||
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
|
||||
module: Optional[str] = Field(None, description="The source code of the function.")
|
||||
source_code: Optional[str] = Field(None, description="The source code of the function.")
|
||||
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
|
||||
class UserBase(LettaBase):
|
||||
@@ -22,7 +22,7 @@ class User(UserBase):
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="The id of the user.")
|
||||
organization_id: Optional[str] = Field(DEFAULT_ORG_ID, description="The organization id of the user")
|
||||
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
|
||||
name: str = Field(..., description="The name of the user.")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="The update date of the user.")
|
||||
|
||||
@@ -31,7 +31,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
Create a new user in the database
|
||||
"""
|
||||
try:
|
||||
user = server.create_user(request)
|
||||
user = server.user_manager.create_user(request)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
@@ -13,13 +14,12 @@ router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
def delete_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
# actor = server.get_user_or_default(user_id=user_id)
|
||||
server.delete_tool(tool_id=tool_id)
|
||||
server.tool_manager.delete_tool(tool_id=tool_id)
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=Tool, operation_id="get_tool")
|
||||
@@ -30,9 +30,7 @@ def get_tool(
|
||||
"""
|
||||
Get a tool by ID
|
||||
"""
|
||||
# actor = server.get_current_user()
|
||||
|
||||
tool = server.get_tool(tool_id=tool_id)
|
||||
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
|
||||
@@ -50,26 +48,26 @@ def get_tool_id(
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
tool_id = server.get_tool_id(tool_name, user_id=actor.id)
|
||||
if tool_id is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
|
||||
return tool_id
|
||||
try:
|
||||
tool = server.tool_manager.get_tool_by_name_and_org_id(tool_name=tool_name, organization_id=actor.organization_id)
|
||||
return tool.id
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} and organization id {actor.organization_id} not found.")
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Tool], operation_id="list_tools")
|
||||
def list_all_tools(
|
||||
def list_tools(
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
Get a list of all tools available to agents belonging to the org of the user
|
||||
"""
|
||||
try:
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.list_tools(cursor=cursor, limit=limit, user_id=actor.id)
|
||||
return server.tool_manager.list_tools_for_org(organization_id=actor.organization_id, cursor=cursor, limit=limit)
|
||||
except Exception as e:
|
||||
# Log or print the full exception here for debugging
|
||||
print(f"Error occurred: {e}")
|
||||
@@ -78,21 +76,21 @@ def list_all_tools(
|
||||
|
||||
@router.post("/", response_model=Tool, operation_id="create_tool")
|
||||
def create_tool(
|
||||
tool: ToolCreate = Body(...),
|
||||
update: bool = False,
|
||||
request: ToolCreate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
# Derive user and org id from actor
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
request.organization_id = actor.organization_id
|
||||
request.user_id = actor.id
|
||||
|
||||
return server.create_tool(
|
||||
request=tool,
|
||||
# update=update,
|
||||
update=True,
|
||||
user_id=actor.id,
|
||||
# Send request to create the tool
|
||||
return server.tool_manager.create_or_update_tool(
|
||||
tool_create=request,
|
||||
)
|
||||
|
||||
|
||||
@@ -106,6 +104,4 @@ def update_tool(
|
||||
"""
|
||||
Update an existing tool
|
||||
"""
|
||||
assert tool_id == request.id, "Tool ID in path must match tool ID in request body"
|
||||
# actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.update_tool(request, user_id)
|
||||
return server.tool_manager.update_tool_by_id(tool_id, request)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
# inspecting tools
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
@@ -16,7 +14,6 @@ import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.agent_store.db import attach_base
|
||||
from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.client.utils import derive_function_name_regex
|
||||
from letta.credentials import LettaCredentials
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
|
||||
@@ -30,11 +27,7 @@ from letta.data_sources.connectors import DataConnector, load_data
|
||||
# Token,
|
||||
# User,
|
||||
# )
|
||||
from letta.functions.functions import (
|
||||
generate_schema,
|
||||
load_function_set,
|
||||
parse_source_code,
|
||||
)
|
||||
from letta.functions.functions import generate_schema, parse_source_code
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
|
||||
# TODO use custom interface
|
||||
@@ -82,10 +75,11 @@ from letta.schemas.memory import (
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User, UserCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.utils import create_random_username, json_dumps, json_loads
|
||||
|
||||
@@ -214,6 +208,7 @@ class SyncServer(Server):
|
||||
chaining: bool = True,
|
||||
max_chaining_steps: Optional[bool] = None,
|
||||
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
|
||||
init_with_default_org_and_user: bool = True,
|
||||
# default_interface: AgentInterface = CLIInterface(),
|
||||
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
|
||||
# auth_mode: str = "none", # "none, "jwt", "external"
|
||||
@@ -249,13 +244,19 @@ class SyncServer(Server):
|
||||
# Managers that interface with data models
|
||||
self.organization_manager = OrganizationManager()
|
||||
self.user_manager = UserManager()
|
||||
self.tool_manager = ToolManager()
|
||||
|
||||
# TODO: this should be removed
|
||||
# add global default tools (for admin)
|
||||
self.add_default_tools(module_name="base")
|
||||
# Make default user and org
|
||||
if init_with_default_org_and_user:
|
||||
self.default_org = self.organization_manager.create_default_organization()
|
||||
self.default_user = self.user_manager.create_default_user()
|
||||
self.add_default_blocks(self.default_user.id)
|
||||
self.tool_manager.add_default_tools(module_name="base", user_id=self.default_user.id, org_id=self.default_org.id)
|
||||
|
||||
if settings.load_default_external_tools:
|
||||
self.add_default_external_tools()
|
||||
# If there is a default org/user
|
||||
# This logic may have to change in the future
|
||||
if settings.load_default_external_tools:
|
||||
self.add_default_external_tools(user_id=self.default_user.id, org_id=self.default_org.id)
|
||||
|
||||
# collect providers (always has Letta as a default)
|
||||
self._enabled_providers: List[Provider] = [LettaProvider()]
|
||||
@@ -364,7 +365,7 @@ class SyncServer(Server):
|
||||
logger.debug(f"Creating an agent object")
|
||||
tool_objs = []
|
||||
for name in agent_state.tools:
|
||||
tool_obj = self.ms.get_tool(tool_name=name, user_id=user_id)
|
||||
tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=name, user_id=user_id)
|
||||
if not tool_obj:
|
||||
logger.exception(f"Tool {name} does not exist for user {user_id}")
|
||||
raise ValueError(f"Tool {name} does not exist for user {user_id}")
|
||||
@@ -755,22 +756,6 @@ class SyncServer(Server):
|
||||
command = command[1:] # strip the prefix
|
||||
return self._command(user_id=user_id, agent_id=agent_id, command=command)
|
||||
|
||||
def create_user(self, request: UserCreate) -> User:
|
||||
"""Create a new user using a config"""
|
||||
if not request.name:
|
||||
# auto-generate a name
|
||||
request.name = create_random_username()
|
||||
user = self.user_manager.create_user(request)
|
||||
logger.debug(f"Created new user from config: {user}")
|
||||
|
||||
# add default for the user
|
||||
# TODO: move to org
|
||||
assert user.id is not None, f"User id is None: {user}"
|
||||
self.add_default_blocks(user.id)
|
||||
self.add_default_tools(module_name="base", user_id=user.id)
|
||||
|
||||
return user
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
request: CreateAgent,
|
||||
@@ -816,8 +801,7 @@ class SyncServer(Server):
|
||||
tool_objs = []
|
||||
if request.tools:
|
||||
for tool_name in request.tools:
|
||||
tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
|
||||
assert tool_obj, f"Tool {tool_name} does not exist"
|
||||
tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=user_id)
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
assert request.memory is not None
|
||||
@@ -832,16 +816,15 @@ class SyncServer(Server):
|
||||
json_schema = generate_schema(func, terminal=False, name=func_name)
|
||||
source_type = "python"
|
||||
tags = ["memory", "memgpt-base"]
|
||||
tool = self.create_tool(
|
||||
request=ToolCreate(
|
||||
tool = self.tool_manager.create_or_update_tool(
|
||||
ToolCreate(
|
||||
source_code=source_code,
|
||||
source_type=source_type,
|
||||
tags=tags,
|
||||
json_schema=json_schema,
|
||||
user_id=user_id,
|
||||
),
|
||||
update=True,
|
||||
user_id=user_id,
|
||||
organization_id=user.organization_id,
|
||||
)
|
||||
)
|
||||
tool_objs.append(tool)
|
||||
if not request.tools:
|
||||
@@ -939,7 +922,7 @@ class SyncServer(Server):
|
||||
# (1) get tools + make sure they exist
|
||||
tool_objs = []
|
||||
for tool_name in request.tools:
|
||||
tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
|
||||
tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=user_id)
|
||||
assert tool_obj, f"Tool {tool_name} does not exist"
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
@@ -995,12 +978,12 @@ class SyncServer(Server):
|
||||
|
||||
# Get all the tool objects from the request
|
||||
tool_objs = []
|
||||
tool_obj = self.ms.get_tool(tool_id=tool_id, user_id=user_id)
|
||||
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id)
|
||||
assert tool_obj, f"Tool with id={tool_id} does not exist"
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
for tool in letta_agent.tools:
|
||||
tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
|
||||
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id)
|
||||
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
||||
|
||||
# If it's not the already added tool
|
||||
@@ -1035,7 +1018,7 @@ class SyncServer(Server):
|
||||
# Get all the tool_objs
|
||||
tool_objs = []
|
||||
for tool in letta_agent.tools:
|
||||
tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
|
||||
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id)
|
||||
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
||||
|
||||
# If it's not the tool we want to remove
|
||||
@@ -1076,86 +1059,6 @@ class SyncServer(Server):
|
||||
agents_states = self.ms.list_agents(user_id=user_id)
|
||||
return agents_states
|
||||
|
||||
# TODO make return type pydantic
|
||||
def list_agents_legacy(
|
||||
self,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
"""List all available agents to a user"""
|
||||
|
||||
if user_id is None:
|
||||
agents_states = self.ms.list_all_agents()
|
||||
else:
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
agents_states = self.ms.list_agents(user_id=user_id)
|
||||
|
||||
agents_states_dicts = [self._agent_state_to_config(state) for state in agents_states]
|
||||
|
||||
# TODO add a get_message_obj_from_message_id(...) function
|
||||
# this would allow grabbing Message.created_by without having to load the agent object
|
||||
# all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
|
||||
self.ms.list_tools()
|
||||
|
||||
for agent_state, return_dict in zip(agents_states, agents_states_dicts):
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self._get_or_load_agent(user_id=agent_state.user_id, agent_id=agent_state.id)
|
||||
|
||||
# TODO remove this eventually when return type get pydanticfied
|
||||
# this is to add persona_name and human_name so that the columns in UI can populate
|
||||
# TODO hack for frontend, remove
|
||||
# (top level .persona is persona_name, and nested memory.persona is the state)
|
||||
# TODO: eventually modify this to be contained in the metadata
|
||||
return_dict["persona"] = agent_state._metadata.get("persona", None)
|
||||
return_dict["human"] = agent_state._metadata.get("human", None)
|
||||
|
||||
# Add information about tools
|
||||
# TODO letta_agent should really have a field of List[ToolModel]
|
||||
# then we could just pull that field and return it here
|
||||
# return_dict["tools"] = [tool for tool in all_available_tools if tool.json_schema in letta_agent.functions]
|
||||
|
||||
# get tool info from agent state
|
||||
tools = []
|
||||
for tool_name in agent_state.tools:
|
||||
tool = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
|
||||
tools.append(tool)
|
||||
return_dict["tools"] = tools
|
||||
|
||||
# Add information about memory (raw core, size of recall, size of archival)
|
||||
core_memory = letta_agent.memory
|
||||
recall_memory = letta_agent.persistence_manager.recall_memory
|
||||
archival_memory = letta_agent.persistence_manager.archival_memory
|
||||
memory_obj = {
|
||||
"core_memory": core_memory.to_flat_dict(),
|
||||
"recall_memory": len(recall_memory) if recall_memory is not None else None,
|
||||
"archival_memory": len(archival_memory) if archival_memory is not None else None,
|
||||
}
|
||||
return_dict["memory"] = memory_obj
|
||||
|
||||
# Add information about last run
|
||||
# NOTE: 'last_run' is just the timestamp on the latest message in the buffer
|
||||
# Retrieve the Message object via the recall storage or by directly access _messages
|
||||
last_msg_obj = letta_agent._messages[-1]
|
||||
return_dict["last_run"] = last_msg_obj.created_at
|
||||
|
||||
# Add information about attached sources
|
||||
sources_ids = self.ms.list_attached_sources(agent_id=agent_state.id)
|
||||
sources = [self.ms.get_source(source_id=s_id) for s_id in sources_ids]
|
||||
return_dict["sources"] = [vars(s) for s in sources]
|
||||
|
||||
# Sort agents by "last_run" in descending order, most recent first
|
||||
agents_states_dicts.sort(key=lambda x: x["last_run"], reverse=True)
|
||||
|
||||
logger.debug(f"Retrieved {len(agents_states)} agents for user {user_id}")
|
||||
return {
|
||||
"num_agents": len(agents_states),
|
||||
"agents": agents_states_dicts,
|
||||
}
|
||||
|
||||
# blocks
|
||||
|
||||
def get_blocks(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
@@ -1830,195 +1733,17 @@ class SyncServer(Server):
|
||||
|
||||
return sources_with_metadata
|
||||
|
||||
def get_tool(self, tool_id: str) -> Optional[Tool]:
|
||||
"""Get tool by ID."""
|
||||
return self.ms.get_tool(tool_id=tool_id)
|
||||
|
||||
def tool_with_name_and_user_id_exists(self, tool: Tool, user_id: Optional[str] = None) -> bool:
|
||||
"""Check if tool exists"""
|
||||
tool = self.ms.get_tool_with_name_and_user_id(tool_name=tool.name, user_id=user_id)
|
||||
|
||||
if tool is None:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_tool_id(self, name: str, user_id: str) -> Optional[str]:
|
||||
"""Get tool ID from name and user_id."""
|
||||
tool = self.ms.get_tool(tool_name=name, user_id=user_id)
|
||||
if not tool or tool.id is None:
|
||||
return None
|
||||
return tool.id
|
||||
|
||||
def update_tool(self, request: ToolUpdate, user_id: Optional[str] = None) -> Tool:
|
||||
"""Update an existing tool"""
|
||||
if request.name:
|
||||
existing_tool = self.ms.get_tool_with_name_and_user_id(tool_name=request.name, user_id=user_id)
|
||||
if existing_tool is None:
|
||||
raise ValueError(f"Tool with name={request.name}, user_id={user_id} does not exist")
|
||||
else:
|
||||
existing_tool = self.ms.get_tool(tool_id=request.id)
|
||||
if existing_tool is None:
|
||||
raise ValueError(f"Tool with id={request.id} does not exist")
|
||||
|
||||
# Preserve the original tool id
|
||||
# As we can override the tool id as well
|
||||
# This is probably bad design if this is exposed to users...
|
||||
original_id = existing_tool.id
|
||||
|
||||
# override updated fields
|
||||
if request.id:
|
||||
existing_tool.id = request.id
|
||||
if request.description:
|
||||
existing_tool.description = request.description
|
||||
if request.source_code:
|
||||
existing_tool.source_code = request.source_code
|
||||
if request.source_type:
|
||||
existing_tool.source_type = request.source_type
|
||||
if request.tags:
|
||||
existing_tool.tags = request.tags
|
||||
if request.json_schema:
|
||||
existing_tool.json_schema = request.json_schema
|
||||
|
||||
# If name is explicitly provided here, overide the tool name
|
||||
if request.name:
|
||||
existing_tool.name = request.name
|
||||
# Otherwise, if there's no name, and there's source code, we try to derive the name
|
||||
elif request.source_code:
|
||||
existing_tool.name = derive_function_name_regex(request.source_code)
|
||||
|
||||
self.ms.update_tool(original_id, existing_tool)
|
||||
return self.ms.get_tool(tool_id=request.id)
|
||||
|
||||
def create_tool(self, request: ToolCreate, user_id: Optional[str] = None, update: bool = True) -> Tool: # TODO: add other fields
|
||||
"""Create a new tool"""
|
||||
|
||||
# NOTE: deprecated code that existed when we were trying to pretend that `self` was the memory object
|
||||
# if request.tags and "memory" in request.tags:
|
||||
# # special modifications to memory functions
|
||||
# # self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory)
|
||||
# request.source_code = request.source_code.replace("self.memory", "self.memory.memory")
|
||||
|
||||
if not request.json_schema:
|
||||
# auto-generate openai schema
|
||||
try:
|
||||
env = {}
|
||||
env.update(globals())
|
||||
exec(request.source_code, env)
|
||||
|
||||
# get available functions
|
||||
functions = [f for f in env if callable(env[f])]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute source code: {e}")
|
||||
|
||||
# TODO: not sure if this always works
|
||||
func = env[functions[-1]]
|
||||
json_schema = generate_schema(func, terminal=request.terminal)
|
||||
else:
|
||||
# provided by client
|
||||
json_schema = request.json_schema
|
||||
|
||||
if not request.name:
|
||||
# use name from JSON schema
|
||||
request.name = json_schema["name"]
|
||||
assert request.name, f"Tool name must be provided in json_schema {json_schema}. This should never happen."
|
||||
|
||||
# check if already exists:
|
||||
existing_tool = self.ms.get_tool(tool_id=request.id, tool_name=request.name, user_id=user_id)
|
||||
if existing_tool:
|
||||
if update:
|
||||
# id is an optional field, so we will fill it with the existing tool id
|
||||
if not request.id:
|
||||
request.id = existing_tool.id
|
||||
updated_tool = self.update_tool(ToolUpdate(**vars(request)), user_id)
|
||||
assert updated_tool is not None, f"Failed to update tool {request.name}"
|
||||
return updated_tool
|
||||
else:
|
||||
raise ValueError(f"Tool {request.name} already exists and update=False")
|
||||
|
||||
# check for description
|
||||
description = None
|
||||
if request.description:
|
||||
description = request.description
|
||||
|
||||
tool = Tool(
|
||||
name=request.name,
|
||||
source_code=request.source_code,
|
||||
source_type=request.source_type,
|
||||
tags=request.tags,
|
||||
json_schema=json_schema,
|
||||
user_id=user_id,
|
||||
description=description,
|
||||
)
|
||||
|
||||
if request.id:
|
||||
tool.id = request.id
|
||||
|
||||
self.ms.create_tool(tool)
|
||||
created_tool = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
|
||||
return created_tool
|
||||
|
||||
def delete_tool(self, tool_id: str):
|
||||
"""Delete a tool"""
|
||||
self.ms.delete_tool(tool_id)
|
||||
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[Tool]:
|
||||
"""List tools available to user_id"""
|
||||
tools = self.ms.list_tools(cursor=cursor, limit=limit, user_id=user_id)
|
||||
return tools
|
||||
|
||||
def add_default_tools(self, module_name="base", user_id: Optional[str] = None):
|
||||
"""Add default tools in {module_name}.py"""
|
||||
full_module_name = f"letta.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
raise e
|
||||
|
||||
functions_to_schema = []
|
||||
try:
|
||||
# Load the function set
|
||||
functions_to_schema = load_function_set(module)
|
||||
except ValueError as e:
|
||||
err = f"Error loading function set '{module_name}': {e}"
|
||||
warnings.warn(err)
|
||||
|
||||
# create tool in db
|
||||
for name, schema in functions_to_schema.items():
|
||||
# print([str(inspect.getsource(line)) for line in schema["imports"]])
|
||||
source_code = inspect.getsource(schema["python_function"])
|
||||
tags = [module_name]
|
||||
if module_name == "base":
|
||||
tags.append("letta-base")
|
||||
|
||||
# create to tool
|
||||
self.create_tool(
|
||||
ToolCreate(
|
||||
name=name,
|
||||
tags=tags,
|
||||
source_type="python",
|
||||
module=schema["module"],
|
||||
source_code=source_code,
|
||||
json_schema=schema["json_schema"],
|
||||
user_id=user_id,
|
||||
),
|
||||
update=True,
|
||||
)
|
||||
|
||||
def add_default_external_tools(self, user_id: Optional[str] = None) -> bool:
|
||||
def add_default_external_tools(self, user_id: str, org_id: str) -> bool:
|
||||
"""Add default langchain tools. Return true if successful, false otherwise."""
|
||||
success = True
|
||||
tool_creates = ToolCreate.load_default_langchain_tools() + ToolCreate.load_default_crewai_tools()
|
||||
if tool_settings.composio_api_key:
|
||||
tools = Tool.load_default_langchain_tools() + Tool.load_default_crewai_tools() + Tool.load_default_composio_tools()
|
||||
else:
|
||||
tools = Tool.load_default_langchain_tools() + Tool.load_default_crewai_tools()
|
||||
for tool in tools:
|
||||
tool_creates += ToolCreate.load_default_composio_tools()
|
||||
for tool_create in tool_creates:
|
||||
try:
|
||||
self.ms.create_tool(tool)
|
||||
self.tool_manager.create_or_update_tool(tool_create)
|
||||
except Exception as e:
|
||||
warnings.warn(f"An error occurred while creating tool {tool}: {e}")
|
||||
warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
|
||||
warnings.warn(traceback.format_exc())
|
||||
success = False
|
||||
|
||||
@@ -2108,25 +1833,15 @@ class SyncServer(Server):
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
return letta_agent.retry_message()
|
||||
|
||||
# TODO: Move a lot of this default logic to the ORM
|
||||
def get_default_user(self) -> User:
|
||||
self.organization_manager.create_default_organization()
|
||||
user = self.user_manager.create_default_user()
|
||||
|
||||
self.add_default_blocks(user.id)
|
||||
self.add_default_tools(module_name="base", user_id=user.id)
|
||||
|
||||
return user
|
||||
|
||||
def get_user_or_default(self, user_id: Optional[str]) -> User:
|
||||
"""Get the user object for user_id if it exists, otherwise return the default user object"""
|
||||
if user_id is None:
|
||||
return self.get_default_user()
|
||||
else:
|
||||
try:
|
||||
return self.user_manager.get_user_by_id(user_id=user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")
|
||||
user_id = self.user_manager.DEFAULT_USER_ID
|
||||
|
||||
try:
|
||||
return self.user_manager.get_user_by_id(user_id=user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
"""List available models"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.utils import create_random_username, enforce_types
|
||||
|
||||
@@ -10,6 +9,9 @@ from letta.utils import create_random_username, enforce_types
|
||||
class OrganizationManager:
|
||||
"""Manager class to handle business logic related to Organizations."""
|
||||
|
||||
DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000"
|
||||
DEFAULT_ORG_NAME = "default_org"
|
||||
|
||||
def __init__(self):
|
||||
# This is probably horrible but we reuse this technique from metadata.py
|
||||
# TODO: Please refactor this out
|
||||
@@ -19,12 +21,17 @@ class OrganizationManager:
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def get_default_organization(self) -> PydanticOrganization:
|
||||
"""Fetch the default organization."""
|
||||
return self.get_organization_by_id(self.DEFAULT_ORG_ID)
|
||||
|
||||
@enforce_types
|
||||
def get_organization_by_id(self, org_id: str) -> PydanticOrganization:
|
||||
"""Fetch an organization by ID."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
organization = Organization.read(db_session=session, identifier=org_id)
|
||||
organization = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
return organization.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Organization with id {org_id} not found.")
|
||||
@@ -33,7 +40,7 @@ class OrganizationManager:
|
||||
def create_organization(self, name: Optional[str] = None) -> PydanticOrganization:
|
||||
"""Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
|
||||
with self.session_maker() as session:
|
||||
org = Organization(name=name if name else create_random_username())
|
||||
org = OrganizationModel(name=name if name else create_random_username())
|
||||
org.create(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
@@ -43,10 +50,10 @@ class OrganizationManager:
|
||||
with self.session_maker() as session:
|
||||
# Try to get it first
|
||||
try:
|
||||
org = Organization.read(db_session=session, identifier=DEFAULT_ORG_ID)
|
||||
org = OrganizationModel.read(db_session=session, identifier=self.DEFAULT_ORG_ID)
|
||||
# If it doesn't exist, make it
|
||||
except NoResultFound:
|
||||
org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID)
|
||||
org = OrganizationModel(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
|
||||
org.create(session)
|
||||
|
||||
return org.to_pydantic()
|
||||
@@ -55,22 +62,22 @@ class OrganizationManager:
|
||||
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
|
||||
"""Update an organization."""
|
||||
with self.session_maker() as session:
|
||||
organization = Organization.read(db_session=session, identifier=org_id)
|
||||
org = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
if name:
|
||||
organization.name = name
|
||||
organization.update(session)
|
||||
return organization.to_pydantic()
|
||||
org.name = name
|
||||
org.update(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_organization_by_id(self, org_id: str):
|
||||
"""Delete an organization by marking it as deleted."""
|
||||
with self.session_maker() as session:
|
||||
organization = Organization.read(db_session=session, identifier=org_id)
|
||||
organization = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
organization.delete(session)
|
||||
|
||||
@enforce_types
|
||||
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
|
||||
"""List organizations with pagination based on cursor (org_id) and limit."""
|
||||
with self.session_maker() as session:
|
||||
results = Organization.list(db_session=session, cursor=cursor, limit=limit)
|
||||
results = OrganizationModel.list(db_session=session, cursor=cursor, limit=limit)
|
||||
return [org.to_pydantic() for org in results]
|
||||
|
||||
193
letta/services/tool_manager.py
Normal file
193
letta/services/tool_manager.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.functions.functions import derive_openai_json_schema, load_function_set
|
||||
|
||||
# TODO: Remove this once we translate all of these to the ORM
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
from letta.orm.user import User as UserModel
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manager class to handle business logic related to Tools."""
|
||||
|
||||
def __init__(self):
|
||||
# Fetching the db_context similarly as in OrganizationManager
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def create_or_update_tool(self, tool_create: ToolCreate) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Derive json_schema
|
||||
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(tool_create)
|
||||
derived_name = tool_create.name or derived_json_schema["name"]
|
||||
|
||||
try:
|
||||
# NOTE: We use the organization id here
|
||||
# This is important, because even if it's a different user, adding the same tool to the org should not happen
|
||||
tool = self.get_tool_by_name_and_org_id(tool_name=derived_name, organization_id=tool_create.organization_id)
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = tool_create.model_dump(exclude={"user_id", "organization_id", "module", "terminal"}, exclude_unset=True)
|
||||
# Remove redundant update fields
|
||||
update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value}
|
||||
|
||||
# If there's anything to update
|
||||
if update_data:
|
||||
self.update_tool_by_id(tool.id, ToolUpdate(**update_data))
|
||||
else:
|
||||
warnings.warn(
|
||||
f"`create_or_update_tool` was called with user_id={tool_create.user_id}, organization_id={tool_create.organization_id}, name={tool_create.name}, but found existing tool with nothing to update."
|
||||
)
|
||||
except NoResultFound:
|
||||
tool_create.json_schema = derived_json_schema
|
||||
tool_create.name = derived_name
|
||||
tool = self.create_tool(tool_create)
|
||||
|
||||
return tool
|
||||
|
||||
@enforce_types
|
||||
def create_tool(self, tool_create: ToolCreate) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Create the tool
|
||||
with self.session_maker() as session:
|
||||
# Include all fields except 'terminal' (which is not part of ToolModel) at the moment
|
||||
create_data = tool_create.model_dump(exclude={"terminal"})
|
||||
tool = ToolModel(**create_data) # Unpack everything directly into ToolModel
|
||||
tool.create(session)
|
||||
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_tool_by_id(self, tool_id: str) -> PydanticTool:
|
||||
"""Fetch a tool by its ID."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Retrieve tool by id using the Tool model's read method
|
||||
tool = ToolModel.read(db_session=session, identifier=tool_id)
|
||||
# Convert the SQLAlchemy Tool object to PydanticTool
|
||||
return tool.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def get_tool_by_name_and_user_id(self, tool_name: str, user_id: str) -> PydanticTool:
|
||||
"""Retrieve a tool by its name and organization_id."""
|
||||
with self.session_maker() as session:
|
||||
# Use the list method to apply filters
|
||||
results = ToolModel.list(db_session=session, name=tool_name, _user_id=UserModel.get_uid_from_identifier(user_id))
|
||||
|
||||
# Ensure only one result is returned (since there is a unique constraint)
|
||||
if not results:
|
||||
raise NoResultFound(f"Tool with name {tool_name} and user_id {user_id} not found.")
|
||||
|
||||
if len(results) > 1:
|
||||
raise RuntimeError(
|
||||
f"Multiple tools with name {tool_name} and user_id {user_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error."
|
||||
)
|
||||
|
||||
# Return the single result
|
||||
return results[0]
|
||||
|
||||
@enforce_types
|
||||
def get_tool_by_name_and_org_id(self, tool_name: str, organization_id: str) -> PydanticTool:
|
||||
"""Retrieve a tool by its name and organization_id."""
|
||||
with self.session_maker() as session:
|
||||
# Use the list method to apply filters
|
||||
results = ToolModel.list(
|
||||
db_session=session, name=tool_name, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id)
|
||||
)
|
||||
|
||||
# Ensure only one result is returned (since there is a unique constraint)
|
||||
if not results:
|
||||
raise NoResultFound(f"Tool with name {tool_name} and organization_id {organization_id} not found.")
|
||||
|
||||
if len(results) > 1:
|
||||
raise RuntimeError(
|
||||
f"Multiple tools with name {tool_name} and organization_id {organization_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error."
|
||||
)
|
||||
|
||||
# Return the single result
|
||||
return results[0]
|
||||
|
||||
@enforce_types
|
||||
def list_tools_for_org(self, organization_id: str, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
|
||||
"""List all tools with optional pagination using cursor and limit."""
|
||||
with self.session_maker() as session:
|
||||
tools = ToolModel.list(
|
||||
db_session=session, cursor=cursor, limit=limit, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id)
|
||||
)
|
||||
return [tool.to_pydantic() for tool in tools]
|
||||
|
||||
@enforce_types
|
||||
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate) -> None:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
with self.session_maker() as session:
|
||||
# Fetch the tool by ID
|
||||
tool = ToolModel.read(db_session=session, identifier=tool_id)
|
||||
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = tool_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(tool, key, value)
|
||||
|
||||
# Save the updated tool to the database
|
||||
tool.update(db_session=session)
|
||||
|
||||
@enforce_types
|
||||
def delete_tool_by_id(self, tool_id: str) -> None:
|
||||
"""Delete a tool by its ID."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
tool = ToolModel.read(db_session=session, identifier=tool_id)
|
||||
tool.delete(db_session=session)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def add_default_tools(self, user_id: str, org_id: str, module_name="base"):
|
||||
"""Add default tools in {module_name}.py"""
|
||||
full_module_name = f"letta.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
raise e
|
||||
|
||||
functions_to_schema = []
|
||||
try:
|
||||
# Load the function set
|
||||
functions_to_schema = load_function_set(module)
|
||||
except ValueError as e:
|
||||
err = f"Error loading function set '{module_name}': {e}"
|
||||
warnings.warn(err)
|
||||
|
||||
# create tool in db
|
||||
for name, schema in functions_to_schema.items():
|
||||
# print([str(inspect.getsource(line)) for line in schema["imports"]])
|
||||
source_code = inspect.getsource(schema["python_function"])
|
||||
tags = [module_name]
|
||||
if module_name == "base":
|
||||
tags.append("letta-base")
|
||||
|
||||
# create to tool
|
||||
self.create_or_update_tool(
|
||||
ToolCreate(
|
||||
name=name,
|
||||
tags=tags,
|
||||
source_type="python",
|
||||
module=schema["module"],
|
||||
source_code=source_code,
|
||||
json_schema=schema["json_schema"],
|
||||
organization_id=org_id,
|
||||
user_id=user_id,
|
||||
),
|
||||
)
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID, DEFAULT_USER_ID, DEFAULT_USER_NAME
|
||||
|
||||
# TODO: Remove this once we translate all of these to the ORM
|
||||
from letta.metadata import AgentModel, AgentSourceMappingModel, SourceModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.orm.user import User as UserModel
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserCreate, UserUpdate
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class UserManager:
|
||||
"""Manager class to handle business logic related to Users."""
|
||||
|
||||
DEFAULT_USER_NAME = "default_user"
|
||||
DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000"
|
||||
|
||||
def __init__(self):
|
||||
# Fetching the db_context similarly as in OrganizationManager
|
||||
from letta.server.server import db_context
|
||||
@@ -22,7 +22,7 @@ class UserManager:
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def create_default_user(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser:
|
||||
def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
|
||||
"""Create the default user."""
|
||||
with self.session_maker() as session:
|
||||
# Make sure the org id exists
|
||||
@@ -33,10 +33,10 @@ class UserManager:
|
||||
|
||||
# Try to retrieve the user
|
||||
try:
|
||||
user = UserModel.read(db_session=session, identifier=DEFAULT_USER_ID)
|
||||
user = UserModel.read(db_session=session, identifier=self.DEFAULT_USER_ID)
|
||||
except NoResultFound:
|
||||
# If it doesn't exist, make it
|
||||
user = UserModel(id=DEFAULT_USER_ID, name=DEFAULT_USER_NAME, organization_id=org_id)
|
||||
user = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id)
|
||||
user.create(session)
|
||||
|
||||
return user.to_pydantic()
|
||||
@@ -73,11 +73,11 @@ class UserManager:
|
||||
user = UserModel.read(db_session=session, identifier=user_id)
|
||||
user.delete(session)
|
||||
|
||||
# TODO: Remove this once we have ORM models for the Agent, Source, and AgentSourceMapping
|
||||
# TODO: Integrate this via the ORM models for the Agent, Source, and AgentSourceMapping
|
||||
# Cascade delete for related models: Agent, Source, AgentSourceMapping
|
||||
session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
|
||||
session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
|
||||
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
|
||||
# session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
|
||||
# session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
|
||||
# session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
|
||||
|
||||
session.commit()
|
||||
|
||||
@@ -91,6 +91,11 @@ class UserManager:
|
||||
except NoResultFound:
|
||||
raise ValueError(f"User with id {user_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def get_default_user(self) -> PydanticUser:
|
||||
"""Fetch the default user."""
|
||||
return self.get_user_by_id(self.DEFAULT_USER_ID)
|
||||
|
||||
@enforce_types
|
||||
def list_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> Tuple[Optional[str], List[PydanticUser]]:
|
||||
"""List users with pagination using cursor (id) and limit."""
|
||||
|
||||
@@ -165,16 +165,13 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
|
||||
"""
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
|
||||
tool = Tool.from_crewai(crewai_tool)
|
||||
tool_name = tool.name
|
||||
|
||||
# Set up client
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
client.add_tool(tool)
|
||||
tool = client.load_crewai_tool(crewai_tool=crewai_tool)
|
||||
tool_name = tool.name
|
||||
|
||||
# Set up persona for tool usage
|
||||
persona = f"""
|
||||
|
||||
@@ -41,17 +41,17 @@ def test_letta_run_create_new_agent(swap_letta_config):
|
||||
child = pexpect.spawn("poetry run letta run", encoding="utf-8")
|
||||
# Start the letta run command
|
||||
child.logfile = sys.stdout
|
||||
child.expect("Creating new agent", timeout=10)
|
||||
child.expect("Creating new agent", timeout=20)
|
||||
# Optional: LLM model selection
|
||||
try:
|
||||
child.expect("Select LLM model:", timeout=10)
|
||||
child.sendline("\033[B\033[B\033[B\033[B\033[B")
|
||||
child.expect("Select LLM model:", timeout=20)
|
||||
child.sendline("")
|
||||
except (pexpect.TIMEOUT, pexpect.EOF):
|
||||
print("[WARNING] LLM model selection step was skipped.")
|
||||
|
||||
# Optional: Embedding model selection
|
||||
try:
|
||||
child.expect("Select embedding model:", timeout=10)
|
||||
child.expect("Select embedding model:", timeout=20)
|
||||
child.sendline("text-embedding-ada-002")
|
||||
except (pexpect.TIMEOUT, pexpect.EOF):
|
||||
print("[WARNING] Embedding model selection step was skipped.")
|
||||
|
||||
@@ -262,15 +262,6 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
|
||||
assert human.value == "Human text", "Creating human failed"
|
||||
|
||||
|
||||
# def test_tools(client, agent):
|
||||
# tools_response = client.list_tools()
|
||||
# print("TOOLS", tools_response)
|
||||
#
|
||||
# tool_name = "TestTool"
|
||||
# tool_response = client.create_tool(name=tool_name, source_code="print('Hello World')", source_type="python")
|
||||
# assert tool_response, "Creating tool failed"
|
||||
|
||||
|
||||
def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import uuid
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient, RESTClient
|
||||
from letta.client.client import LocalClient
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool import ToolCreate
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -35,7 +34,7 @@ def agent(client):
|
||||
assert client.get_agent(agent_state.id) is None, f"Failed to properly delete agent {agent_state.id}"
|
||||
|
||||
|
||||
def test_agent(client: Union[LocalClient, RESTClient]):
|
||||
def test_agent(client: LocalClient):
|
||||
# create agent
|
||||
agent_state_test = client.create_agent(
|
||||
name="test_agent2",
|
||||
@@ -120,18 +119,16 @@ def test_agent(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent_state_test.id)
|
||||
|
||||
|
||||
def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent):
|
||||
def test_agent_add_remove_tools(client: LocalClient, agent):
|
||||
# Create and add two tools to the client
|
||||
# tool 1
|
||||
from composio_langchain import Action
|
||||
|
||||
github_tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
client.add_tool(github_tool)
|
||||
github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
# tool 2
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
|
||||
scrape_website_tool = Tool.from_crewai(ScrapeWebsiteTool(website_url="https://www.example.com"))
|
||||
client.add_tool(scrape_website_tool)
|
||||
scrape_website_tool = client.load_crewai_tool(crewai_tool=ScrapeWebsiteTool(website_url="https://www.example.com"))
|
||||
|
||||
# assert both got added
|
||||
tools = client.list_tools()
|
||||
@@ -171,7 +168,7 @@ def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent):
|
||||
assert scrape_website_tool.name in curr_tool_names
|
||||
|
||||
|
||||
def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
|
||||
def test_agent_with_shared_blocks(client: LocalClient):
|
||||
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
|
||||
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
|
||||
existing_non_template_blocks = [persona_block, human_block]
|
||||
@@ -222,7 +219,7 @@ def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(second_agent_state_test.id)
|
||||
|
||||
|
||||
def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_memory(client: LocalClient, agent: AgentState):
|
||||
# get agent memory
|
||||
original_memory = client.get_in_context_memory(agent.id)
|
||||
assert original_memory is not None
|
||||
@@ -235,7 +232,7 @@ def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated
|
||||
|
||||
|
||||
def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_archival_memory(client: LocalClient, agent: AgentState):
|
||||
"""Test functions for interacting with archival memory store"""
|
||||
|
||||
# add archival memory
|
||||
@@ -250,7 +247,7 @@ def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentSta
|
||||
client.delete_archival_memory(agent.id, passage.id)
|
||||
|
||||
|
||||
def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_recall_memory(client: LocalClient, agent: AgentState):
|
||||
"""Test functions for interacting with recall memory store"""
|
||||
|
||||
# send message to the agent
|
||||
@@ -274,7 +271,7 @@ def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState
|
||||
assert exists
|
||||
|
||||
|
||||
def test_tools(client: Union[LocalClient, RESTClient]):
|
||||
def test_tools(client: LocalClient):
|
||||
def print_tool(message: str):
|
||||
"""
|
||||
A tool to print a message
|
||||
@@ -314,20 +311,18 @@ def test_tools(client: Union[LocalClient, RESTClient]):
|
||||
assert client.get_tool(tool.id).tags == extras2
|
||||
|
||||
# update tool: source code
|
||||
client.update_tool(tool.id, func=print_tool2)
|
||||
client.update_tool(tool.id, name="print_tool2", func=print_tool2)
|
||||
assert client.get_tool(tool.id).name == "print_tool2"
|
||||
|
||||
|
||||
def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
|
||||
def test_tools_from_composio_basic(client: LocalClient):
|
||||
from composio_langchain import Action
|
||||
|
||||
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
|
||||
client = create_client()
|
||||
|
||||
tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
|
||||
# create tool
|
||||
client.add_tool(tool)
|
||||
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
|
||||
# list tools
|
||||
tools = client.list_tools()
|
||||
@@ -337,18 +332,15 @@ def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
|
||||
# The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable
|
||||
|
||||
|
||||
def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
|
||||
def test_tools_from_crewai(client: LocalClient):
|
||||
# create crewAI tool
|
||||
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
|
||||
crewai_tool = ScrapeWebsiteTool()
|
||||
|
||||
# Translate to memGPT Tool
|
||||
tool = Tool.from_crewai(crewai_tool)
|
||||
|
||||
# Add the tool
|
||||
client.add_tool(tool)
|
||||
tool = client.load_crewai_tool(crewai_tool=crewai_tool)
|
||||
|
||||
# list tools
|
||||
tools = client.list_tools()
|
||||
@@ -372,18 +364,15 @@ def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
|
||||
assert expected_content in func(website_url=simple_webpage_url)
|
||||
|
||||
|
||||
def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
|
||||
def test_tools_from_crewai_with_params(client: LocalClient):
|
||||
# create crewAI tool
|
||||
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
|
||||
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
|
||||
|
||||
# Translate to memGPT Tool
|
||||
tool = Tool.from_crewai(crewai_tool)
|
||||
|
||||
# Add the tool
|
||||
client.add_tool(tool)
|
||||
tool = client.load_crewai_tool(crewai_tool=crewai_tool)
|
||||
|
||||
# list tools
|
||||
tools = client.list_tools()
|
||||
@@ -404,7 +393,7 @@ def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
|
||||
assert expected_content in func()
|
||||
|
||||
|
||||
def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
|
||||
def test_tools_from_langchain(client: LocalClient):
|
||||
# create langchain tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
@@ -412,11 +401,10 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
|
||||
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
|
||||
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
|
||||
# Translate to memGPT Tool
|
||||
tool = Tool.from_langchain(langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"})
|
||||
|
||||
# Add the tool
|
||||
client.add_tool(tool)
|
||||
tool = client.load_langchain_tool(
|
||||
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
)
|
||||
|
||||
# list tools
|
||||
tools = client.list_tools()
|
||||
@@ -436,7 +424,7 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
|
||||
assert expected_content in func(query="Albert Einstein")
|
||||
|
||||
|
||||
def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, RESTClient]):
|
||||
def test_tool_creation_langchain_missing_imports(client: LocalClient):
|
||||
# create langchain tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
@@ -447,4 +435,4 @@ def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, REST
|
||||
# Translate to memGPT Tool
|
||||
# Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
with pytest.raises(RuntimeError):
|
||||
Tool.from_langchain(langchain_tool)
|
||||
ToolCreate.from_langchain(langchain_tool)
|
||||
@@ -2,14 +2,12 @@ import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
import letta.utils as utils
|
||||
from letta.constants import (
|
||||
DEFAULT_ORG_ID,
|
||||
DEFAULT_ORG_NAME,
|
||||
DEFAULT_USER_ID,
|
||||
DEFAULT_USER_NAME,
|
||||
)
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
utils.DEBUG = True
|
||||
from letta.config import LettaConfig
|
||||
@@ -18,21 +16,58 @@ from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_organization_and_user_table(server: SyncServer):
|
||||
def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Tool)) # Clear all records from the Tool table
|
||||
session.execute(delete(User)) # Clear all records from the user table
|
||||
session.execute(delete(Organization)) # Clear all records from the organization table
|
||||
session.commit() # Commit the deletion
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_fixture(server: SyncServer):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
|
||||
def print_tool(message: str):
|
||||
"""
|
||||
Args:
|
||||
message (str): The message to print.
|
||||
|
||||
Returns:
|
||||
str: The message that was printed.
|
||||
"""
|
||||
print(message)
|
||||
return message
|
||||
|
||||
source_code = parse_source_code(print_tool)
|
||||
source_type = "python"
|
||||
description = "test_description"
|
||||
tags = ["test"]
|
||||
|
||||
org = server.organization_manager.create_default_organization()
|
||||
user = server.user_manager.create_default_user()
|
||||
tool_create = ToolCreate(
|
||||
user_id=user.id, organization_id=org.id, description=description, tags=tags, source_code=source_code, source_type=source_type
|
||||
)
|
||||
derived_json_schema = derive_openai_json_schema(tool_create)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool_create.json_schema = derived_json_schema
|
||||
tool_create.name = derived_name
|
||||
|
||||
tool = server.tool_manager.create_tool(tool_create)
|
||||
|
||||
# Yield the created tool, organization, and user for use in tests
|
||||
yield {"tool": tool, "organization": org, "user": user, "tool_create": tool_create}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
server = SyncServer(init_with_default_org_and_user=False)
|
||||
return server
|
||||
|
||||
|
||||
@@ -55,8 +90,8 @@ def test_list_organizations(server: SyncServer):
|
||||
|
||||
def test_create_default_organization(server: SyncServer):
|
||||
server.organization_manager.create_default_organization()
|
||||
retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID)
|
||||
assert retrieved.name == DEFAULT_ORG_NAME
|
||||
retrieved = server.organization_manager.get_default_organization()
|
||||
assert retrieved.name == server.organization_manager.DEFAULT_ORG_NAME
|
||||
|
||||
|
||||
def test_update_organization_name(server: SyncServer):
|
||||
@@ -105,8 +140,8 @@ def test_list_users(server: SyncServer):
|
||||
def test_create_default_user(server: SyncServer):
|
||||
org = server.organization_manager.create_default_organization()
|
||||
server.user_manager.create_default_user(org_id=org.id)
|
||||
retrieved = server.user_manager.get_user_by_id(DEFAULT_USER_ID)
|
||||
assert retrieved.name == DEFAULT_USER_NAME
|
||||
retrieved = server.user_manager.get_default_user()
|
||||
assert retrieved.name == server.user_manager.DEFAULT_USER_NAME
|
||||
|
||||
|
||||
def test_update_user(server: SyncServer):
|
||||
@@ -124,9 +159,117 @@ def test_update_user(server: SyncServer):
|
||||
# Adjust name
|
||||
user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b))
|
||||
assert user.name == user_name_b
|
||||
assert user.organization_id == DEFAULT_ORG_ID
|
||||
assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID
|
||||
|
||||
# Adjust org id
|
||||
user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id))
|
||||
assert user.name == user_name_b
|
||||
assert user.organization_id == test_org.id
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tool Manager Tests
|
||||
# ======================================================================================================================
|
||||
def test_create_tool(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
tool_create = tool_fixture["tool_create"]
|
||||
user = tool_fixture["user"]
|
||||
org = tool_fixture["organization"]
|
||||
|
||||
# Assertions to ensure the created tool matches the expected values
|
||||
assert tool.user_id == user.id
|
||||
assert tool.organization_id == org.id
|
||||
assert tool.description == tool_create.description
|
||||
assert tool.tags == tool_create.tags
|
||||
assert tool.source_code == tool_create.source_code
|
||||
assert tool.source_type == tool_create.source_type
|
||||
assert tool.json_schema == derive_openai_json_schema(tool_create)
|
||||
|
||||
|
||||
def test_get_tool_by_id(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
|
||||
# Fetch the tool by ID using the manager method
|
||||
fetched_tool = server.tool_manager.get_tool_by_id(tool.id)
|
||||
|
||||
# Assertions to check if the fetched tool matches the created tool
|
||||
assert fetched_tool.id == tool.id
|
||||
assert fetched_tool.name == tool.name
|
||||
assert fetched_tool.description == tool.description
|
||||
assert fetched_tool.tags == tool.tags
|
||||
assert fetched_tool.source_code == tool.source_code
|
||||
assert fetched_tool.source_type == tool.source_type
|
||||
|
||||
|
||||
def test_get_tool_by_name_and_org_id(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
org = tool_fixture["organization"]
|
||||
|
||||
# Fetch the tool by name and organization ID
|
||||
fetched_tool = server.tool_manager.get_tool_by_name_and_org_id(tool.name, org.id)
|
||||
|
||||
# Assertions to check if the fetched tool matches the created tool
|
||||
assert fetched_tool.id == tool.id
|
||||
assert fetched_tool.name == tool.name
|
||||
assert fetched_tool.organization_id == org.id
|
||||
assert fetched_tool.description == tool.description
|
||||
assert fetched_tool.tags == tool.tags
|
||||
assert fetched_tool.source_code == tool.source_code
|
||||
assert fetched_tool.source_type == tool.source_type
|
||||
|
||||
|
||||
def test_get_tool_by_name_and_user_id(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
|
||||
# Fetch the tool by name and organization ID
|
||||
fetched_tool = server.tool_manager.get_tool_by_name_and_user_id(tool.name, user.id)
|
||||
|
||||
# Assertions to check if the fetched tool matches the created tool
|
||||
assert fetched_tool.id == tool.id
|
||||
assert fetched_tool.name == tool.name
|
||||
assert fetched_tool.user_id == user.id
|
||||
assert fetched_tool.description == tool.description
|
||||
assert fetched_tool.tags == tool.tags
|
||||
assert fetched_tool.source_code == tool.source_code
|
||||
assert fetched_tool.source_type == tool.source_type
|
||||
|
||||
|
||||
def test_list_tools(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
org = tool_fixture["organization"]
|
||||
|
||||
# List tools (should include the one created by the fixture)
|
||||
tools = server.tool_manager.list_tools_for_org(organization_id=org.id)
|
||||
|
||||
# Assertions to check that the created tool is listed
|
||||
assert len(tools) == 1
|
||||
assert any(t.id == tool.id for t in tools)
|
||||
|
||||
|
||||
def test_update_tool_by_id(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
updated_description = "updated_description"
|
||||
|
||||
# Create a ToolUpdate object to modify the tool's description
|
||||
tool_update = ToolUpdate(description=updated_description)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id)
|
||||
|
||||
# Assertions to check if the update was successful
|
||||
assert updated_tool.description == updated_description
|
||||
|
||||
|
||||
def test_delete_tool_by_id(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
org = tool_fixture["organization"]
|
||||
|
||||
# Delete the tool using the manager method
|
||||
server.tool_manager.delete_tool_by_id(tool.id)
|
||||
|
||||
tools = server.tool_manager.list_tools_for_org(organization_id=org.id)
|
||||
assert len(tools) == 0
|
||||
|
||||
@@ -25,7 +25,6 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.source import SourceCreate
|
||||
from letta.schemas.user import UserCreate
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
from .utils import DummyDataConnector
|
||||
@@ -33,26 +32,9 @@ from .utils import DummyDataConnector
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
# create_config("openai")
|
||||
# credentials = LettaCredentials(
|
||||
# openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
# )
|
||||
# else: # hosted
|
||||
# create_config("letta_hosted")
|
||||
# credentials = LettaCredentials()
|
||||
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
## set to use postgres
|
||||
# config.archival_storage_uri = db_url
|
||||
# config.recall_storage_uri = db_url
|
||||
# config.metadata_storage_uri = db_url
|
||||
# config.archival_storage_type = "postgres"
|
||||
# config.recall_storage_type = "postgres"
|
||||
# config.metadata_storage_type = "postgres"
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
@@ -62,7 +44,7 @@ def server():
|
||||
@pytest.fixture(scope="module")
|
||||
def org_id(server):
|
||||
# create org
|
||||
org = server.organization_manager.create_organization(name="test_org")
|
||||
org = server.organization_manager.create_default_organization()
|
||||
print(f"Created org\n{org.id}")
|
||||
|
||||
yield org.id
|
||||
@@ -74,7 +56,7 @@ def org_id(server):
|
||||
@pytest.fixture(scope="module")
|
||||
def user_id(server, org_id):
|
||||
# create user
|
||||
user = server.create_user(UserCreate(name="test_user", organization_id=org_id))
|
||||
user = server.user_manager.create_default_user()
|
||||
print(f"Created user\n{user.id}")
|
||||
|
||||
yield user.id
|
||||
|
||||
@@ -56,13 +56,13 @@ def client(request):
|
||||
time.sleep(5)
|
||||
print("Running client tests with server:", server_url)
|
||||
else:
|
||||
server_url = None
|
||||
assert False, "Local client not implemented"
|
||||
|
||||
assert server_url is not None
|
||||
client = create_client(base_url=server_url) # This yields control back to the test function
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
# Clear all records from the Tool table
|
||||
yield client
|
||||
|
||||
|
||||
@@ -93,7 +93,16 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
|
||||
return message
|
||||
|
||||
tools = client.list_tools()
|
||||
print(f"Original tools {[t.name for t in tools]}")
|
||||
assert sorted([t.name for t in tools]) == sorted(
|
||||
[
|
||||
"archival_memory_search",
|
||||
"send_message",
|
||||
"pause_heartbeats",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
]
|
||||
)
|
||||
|
||||
tool = client.create_tool(print_tool, name="my_name", tags=["extras"])
|
||||
|
||||
@@ -108,13 +117,15 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
|
||||
|
||||
# create agent with tool
|
||||
agent_state = client.create_agent(tools=[tool.name])
|
||||
response = client.user_message(agent_id=agent_state.id, message="hi")
|
||||
|
||||
# Send message without error
|
||||
client.user_message(agent_id=agent_state.id, message="hi")
|
||||
|
||||
|
||||
def test_create_agent_tool(client):
|
||||
"""Test creation of a agent tool"""
|
||||
|
||||
def core_memory_clear(self: Agent):
|
||||
def core_memory_clear(self: "Agent"):
|
||||
"""
|
||||
Args:
|
||||
agent (Agent): The agent to delete from memory.
|
||||
|
||||
Reference in New Issue
Block a user