feat: Add orm for Tools and clean up Tool logic (#1935)

This commit is contained in:
Matthew Zhou
2024-10-25 14:25:40 -07:00
committed by GitHub
parent 150240c7a7
commit d74406af41
37 changed files with 833 additions and 789 deletions

View File

@@ -1,21 +1,13 @@
name: Code Style Checks
on:
push:
branches: [ main ]
pull_request:
paths:
- '**.py'
pull_request_target:
types:
- opened
- edited
- synchronize
workflow_dispatch:
permissions:
pull-requests: read
branches: [ main ]
jobs:
validation-checks:
style-checks:
runs-on: ubuntu-latest
strategy:
matrix:

View File

@@ -1,4 +1,4 @@
name: Run CLI tests
name: Test CLI
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -10,7 +10,7 @@ on:
branches: [ main ]
jobs:
test:
test-cli:
runs-on: ubuntu-latest
timeout-minutes: 15

View File

@@ -1,4 +1,4 @@
name: Run All pytest Tests
name: Unit Tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -13,7 +13,7 @@ on:
branches: [ main ]
jobs:
test:
unit-tests:
runs-on: ubuntu-latest
timeout-minutes: 15
@@ -37,6 +37,28 @@ jobs:
poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests"
- name: Run LocalClient tests
env:
LETTA_PG_PORT: 8888
LETTA_PG_USER: letta
LETTA_PG_PASSWORD: letta
LETTA_PG_DB: letta
LETTA_PG_HOST: localhost
LETTA_SERVER_PASS: test_server_token
run: |
poetry run pytest -s -vv tests/test_local_client.py
- name: Run RESTClient tests
env:
LETTA_PG_PORT: 8888
LETTA_PG_USER: letta
LETTA_PG_PASSWORD: letta
LETTA_PG_DB: letta
LETTA_PG_HOST: localhost
LETTA_SERVER_PASS: test_server_token
run: |
poetry run pytest -s -vv tests/test_client.py
- name: Run server tests
env:
LETTA_PG_PORT: 8888
@@ -70,6 +92,17 @@ jobs:
run: |
poetry run pytest -s -vv tests/test_tools.py
- name: Run o1 agent tests
env:
LETTA_PG_PORT: 8888
LETTA_PG_USER: letta
LETTA_PG_PASSWORD: letta
LETTA_PG_DB: letta
LETTA_PG_HOST: localhost
LETTA_SERVER_PASS: test_server_token
run: |
poetry run pytest -s -vv tests/test_o1_agent.py
- name: Run tests with pytest
env:
LETTA_PG_PORT: 8888
@@ -80,4 +113,4 @@ jobs:
LETTA_SERVER_PASS: test_server_token
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers" tests
poetry run pytest -s -vv -k "not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests

View File

@@ -5,7 +5,6 @@ from letta import create_client
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.tool import Tool
"""
Setup here.
@@ -49,10 +48,7 @@ def main():
from composio_langchain import Action
# Add the composio tool
tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
# create tool
client.add_tool(tool)
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
persona = f"""
My name is Letta.

View File

@@ -2,8 +2,9 @@ import json
import uuid
from letta import create_client
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.tool import Tool
"""
This example show how you can add CrewAI tools .
@@ -21,14 +22,14 @@ def main():
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
example_website_scrape_tool = Tool.from_crewai(crewai_tool)
tool_name = example_website_scrape_tool.name
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
# create tool
client.add_tool(example_website_scrape_tool)
example_website_scrape_tool = client.load_crewai_tool(crewai_tool)
tool_name = example_website_scrape_tool.name
# Confirm that the tool is in
tools = client.list_tools()

View File

@@ -5,7 +5,6 @@ from letta import create_client
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.tool import Tool
"""
This example show how you can add LangChain tools .
@@ -26,25 +25,22 @@ def main():
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=500)
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
# Translate to memGPT Tool
# Note the additional_imports_module_attr_map
# We need to pass in a map of all the additional imports necessary to run this tool
# Because an object of type WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool,
# We need to also import WikipediaAPIWrapper
# The map is a mapping of the module name to the attribute name
# langchain_community.utilities.WikipediaAPIWrapper
wikipedia_query_tool = Tool.from_langchain(
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
)
tool_name = wikipedia_query_tool.name
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
# create tool
client.add_tool(wikipedia_query_tool)
# Note the additional_imports_module_attr_map
# We need to pass in a map of all the additional imports necessary to run this tool
# Because an object of type WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool,
# We need to also import WikipediaAPIWrapper
# The map is a mapping of the module name to the attribute name
# langchain_community.utilities.WikipediaAPIWrapper
wikipedia_query_tool = client.load_langchain_tool(
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
)
tool_name = wikipedia_query_tool.name
# Confirm that the tool is in
tools = client.list_tools()

View File

@@ -240,7 +240,6 @@ class Agent(BaseAgent):
assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}"
# link tools
self.tools = tools
self.link_tools(tools)
# gpt-4, gpt-3.5-turbo, ...
@@ -1521,10 +1520,6 @@ def save_agent(agent: Agent, ms: MetadataStore):
else:
ms.create_agent(agent_state)
for tool in agent.tools:
if ms.get_tool(tool_name=tool.name, user_id=tool.user_id) is None:
ms.create_tool(tool)
agent.agent_state = ms.get_agent(agent_id=agent_id)
assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"

View File

@@ -133,6 +133,7 @@ def run(
# read user id from config
ms = MetadataStore(config)
client = create_client()
server = client.server
# determine agent to use, if not provided
if not yes and not agent:
@@ -217,7 +218,9 @@ def run(
)
# create agent
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
tools = [
server.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=client.user_id) for tool_name in agent_state.tools
]
letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)
else: # create new agent
@@ -297,7 +300,7 @@ def run(
)
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
tools = [server.tool_manager.get_tool_by_name_and_user_id(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
letta_agent = Agent(
interface=interface(),

View File

@@ -182,6 +182,15 @@ class AbstractClient(object):
def delete_human(self, id: str):
raise NotImplementedError
def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
raise NotImplementedError
def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
raise NotImplementedError
def load_composio_tool(self, action: "ActionType") -> Tool:
raise NotImplementedError
def create_tool(
self,
func,
@@ -1298,12 +1307,6 @@ class RESTClient(AbstractClient):
source_code = parse_source_code(func)
source_type = "python"
# TODO: Check if tool already exists
# if name:
# tool_id = self.get_tool_id(tool_name=name)
# if tool_id:
# raise ValueError(f"Tool with name {name} (id={tool_id}) already exists")
# call server function
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
@@ -1337,7 +1340,7 @@ class RESTClient(AbstractClient):
source_type = "python"
request = ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name)
request = ToolUpdate(source_type=source_type, source_code=source_code, tags=tags, name=name)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update tool: {response.text}")
@@ -1394,6 +1397,7 @@ class RESTClient(AbstractClient):
params["cursor"] = str(cursor)
if limit:
params["limit"] = limit
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", params=params, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list tools: {response.text}")
@@ -1503,6 +1507,7 @@ class LocalClient(AbstractClient):
self,
auto_save: bool = False,
user_id: Optional[str] = None,
org_id: Optional[str] = None,
debug: bool = False,
default_llm_config: Optional[LLMConfig] = None,
default_embedding_config: Optional[EmbeddingConfig] = None,
@@ -1529,15 +1534,19 @@ class LocalClient(AbstractClient):
self.interface = QueuingInterface(debug=debug)
self.server = SyncServer(default_interface_factory=lambda: self.interface)
# save org_id that `LocalClient` is associated with
if org_id:
self.org_id = org_id
else:
self.org_id = self.server.organization_manager.DEFAULT_ORG_ID
# save user_id that `LocalClient` is associated with
if user_id:
self.user_id = user_id
else:
# get default user
self.user_id = self.server.get_default_user().id
self.user_id = self.server.user_manager.DEFAULT_USER_ID
# agents
def list_agents(self) -> List[AgentState]:
self.interface.clear()
@@ -2186,50 +2195,27 @@ class LocalClient(AbstractClient):
self.server.delete_block(id)
# tools
def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
tool_create = ToolCreate.from_langchain(
langchain_tool=langchain_tool,
user_id=self.user_id,
organization_id=self.org_id,
additional_imports_module_attr_map=additional_imports_module_attr_map,
)
return self.server.tool_manager.create_or_update_tool(tool_create)
# TODO: merge this into create_tool
def add_tool(self, tool: Tool, update: Optional[bool] = True) -> Tool:
"""
Adds a tool directly.
def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
tool_create = ToolCreate.from_crewai(
crewai_tool=crewai_tool,
additional_imports_module_attr_map=additional_imports_module_attr_map,
user_id=self.user_id,
organization_id=self.org_id,
)
return self.server.tool_manager.create_or_update_tool(tool_create)
Args:
tool (Tool): The tool to add.
update (bool, optional): Update the tool if it already exists. Defaults to True.
Returns:
None
"""
if self.tool_with_name_and_user_id_exists(tool):
if update:
return self.server.update_tool(
ToolUpdate(
id=tool.id,
description=tool.description,
source_type=tool.source_type,
source_code=tool.source_code,
tags=tool.tags,
json_schema=tool.json_schema,
name=tool.name,
),
self.user_id,
)
else:
raise ValueError(f"Tool with id={tool.id} and name={tool.name}already exists")
else:
# call server function
return self.server.create_tool(
ToolCreate(
id=tool.id,
description=tool.description,
source_type=tool.source_type,
source_code=tool.source_code,
name=tool.name,
json_schema=tool.json_schema,
tags=tool.tags,
),
user_id=self.user_id,
update=update,
)
def load_composio_tool(self, action: "ActionType") -> Tool:
tool_create = ToolCreate.from_composio(action=action, user_id=self.user_id, organization_id=self.org_id)
return self.server.tool_manager.create_or_update_tool(tool_create)
# TODO: Use the above function `add_tool` here as there is duplicate logic
def create_tool(
@@ -2262,11 +2248,16 @@ class LocalClient(AbstractClient):
tags = []
# call server function
return self.server.create_tool(
# ToolCreate(source_type=source_type, source_code=source_code, name=tool_name, json_schema=json_schema, tags=tags),
ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags, terminal=terminal),
user_id=self.user_id,
update=update,
return self.server.tool_manager.create_or_update_tool(
ToolCreate(
user_id=self.user_id,
organization_id=self.org_id,
source_type=source_type,
source_code=source_code,
name=name,
tags=tags,
terminal=terminal,
),
)
def update_tool(
@@ -2288,16 +2279,17 @@ class LocalClient(AbstractClient):
Returns:
tool (Tool): Updated tool
"""
if func:
source_code = parse_source_code(func)
else:
source_code = None
update_data = {
"source_type": "python", # Always include source_type
"source_code": parse_source_code(func) if func else None,
"tags": tags,
"name": name,
}
source_type = "python"
# Filter out any None values from the dictionary
update_data = {key: value for key, value in update_data.items() if value is not None}
return self.server.update_tool(
ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id
)
return self.server.tool_manager.update_tool_by_id(id, ToolUpdate(**update_data))
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
"""
@@ -2306,11 +2298,11 @@ class LocalClient(AbstractClient):
Returns:
tools (List[Tool]): List of tools
"""
return self.server.list_tools(cursor=cursor, limit=limit, user_id=self.user_id)
return self.server.tool_manager.list_tools_for_org(cursor=cursor, limit=limit, organization_id=self.org_id)
def get_tool(self, id: str) -> Optional[Tool]:
"""
Get a tool give its ID.
Get a tool given its ID.
Args:
id (str): ID of the tool
@@ -2318,7 +2310,7 @@ class LocalClient(AbstractClient):
Returns:
tool (Tool): Tool
"""
return self.server.get_tool(id)
return self.server.tool_manager.get_tool_by_id(id)
def delete_tool(self, id: str):
"""
@@ -2327,11 +2319,11 @@ class LocalClient(AbstractClient):
Args:
id (str): ID of the tool
"""
return self.server.delete_tool(id)
return self.server.tool_manager.delete_tool_by_id(id)
def get_tool_id(self, name: str) -> Optional[str]:
"""
Get the ID of a tool
Get the ID of a tool from its name. The client will use the org_id it is configured with.
Args:
name (str): Name of the tool
@@ -2339,19 +2331,8 @@ class LocalClient(AbstractClient):
Returns:
id (str): ID of the tool (`None` if not found)
"""
return self.server.get_tool_id(name, self.user_id)
def tool_with_name_and_user_id_exists(self, tool: Tool) -> bool:
"""
Check if the tool with name and user_id exists
Args:
tool (Tool): the tool
Returns:
(bool): True if the id exists, False otherwise.
"""
return self.server.tool_with_name_and_user_id_exists(tool, self.user_id)
tool = self.server.tool_manager.get_tool_by_name_and_org_id(tool_name=name, organization_id=self.org_id)
return tool.id
def load_data(self, connector: DataConnector, source_name: str):
"""

View File

@@ -13,13 +13,13 @@ from letta.constants import (
DEFAULT_HUMAN,
DEFAULT_PERSONA,
DEFAULT_PRESET,
DEFAULT_USER_ID,
LETTA_DIR,
)
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.services.user_manager import UserManager
logger = get_logger(__name__)
@@ -45,7 +45,7 @@ def set_field(config, section, field, value):
@dataclass
class LettaConfig:
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(LETTA_DIR, "config")
anon_clientid: str = DEFAULT_USER_ID
anon_clientid: str = UserManager.DEFAULT_USER_ID
# preset
preset: str = DEFAULT_PRESET # TODO: rename to system prompt

View File

@@ -3,15 +3,6 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
# Defaults
DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000"
# This UUID follows the UUID4 rules:
# The 13th character (4) indicates it's version 4.
# The first character of the third segment (8) ensures the variant is correctly set.
DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000"
DEFAULT_USER_NAME = "default_user"
DEFAULT_ORG_NAME = "default_org"
# String in the error message for when the context window is too large
# Example full message:

View File

@@ -3,9 +3,33 @@ import inspect
import os
from textwrap import dedent # remove indentation
from types import ModuleType
from typing import Optional
from letta.constants import CLI_WARNING_PREFIX
from letta.functions.schema_generator import generate_schema
from letta.schemas.tool import ToolCreate
def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
# auto-generate openai schema
try:
# Define a custom environment with necessary imports
env = {
"Optional": Optional, # Add any other required imports here
}
env.update(globals())
exec(tool_create.source_code, env)
# get available functions
functions = [f for f in env if callable(env[f])]
# TODO: not sure if this always works
func = env[functions[-1]]
json_schema = generate_schema(func, terminal=tool_create.terminal, name=tool_create.name)
return json_schema
except Exception as e:
raise RuntimeError(f"Failed to execute source code: {e}")
def parse_source_code(func) -> str:

View File

@@ -1,5 +1,6 @@
from typing import Any, Optional, Union
import humps
from pydantic import BaseModel
@@ -8,7 +9,7 @@ def generate_composio_tool_wrapper(action: "ActionType") -> tuple[str, str]:
tool_instantiation_str = f"composio_toolset.get_tools(actions=[Action.{str(action)}])[0]"
# Generate func name
func_name = f"run_{action.name.lower()}"
func_name = action.name.lower()
wrapper_function_str = f"""
def {func_name}(**kwargs):
@@ -40,7 +41,7 @@ def generate_langchain_tool_wrapper(
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
run_call = f"return tool._run(**kwargs)"
func_name = f"run_{tool_name.lower()}"
func_name = humps.decamelize(tool_name)
# Combine all parts into the wrapper function
wrapper_function_str = f"""
@@ -70,7 +71,7 @@ def generate_crewai_tool_wrapper(tool: "CrewAIBaseTool", additional_imports_modu
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
run_call = f"return tool._run(**kwargs)"
func_name = f"run_{tool_name.lower()}"
func_name = humps.decamelize(tool_name)
# Combine all parts into the wrapper function
wrapper_function_str = f"""

View File

@@ -74,7 +74,7 @@ def pydantic_model_to_open_ai(model):
}
def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None):
def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None) -> dict:
# Get the signature of the function
sig = inspect.signature(function)
@@ -139,7 +139,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No
def generate_schema_from_args_schema(
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
) -> Dict[str, Any]:
properties = {}
required = []
@@ -163,4 +163,12 @@ def generate_schema_from_args_schema(
"parameters": {"type": "object", "properties": properties, "required": required},
}
# append heartbeat (necessary for triggering another reasoning step after this tool call)
if append_heartbeat:
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
function_call_json["parameters"]["required"].append("request_heartbeat")
return function_call_json

View File

@@ -14,8 +14,6 @@ from sqlalchemy import (
Integer,
String,
TypeDecorator,
asc,
or_,
)
from sqlalchemy.sql import func
@@ -32,7 +30,6 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.settings import settings
from letta.utils import enforce_types, get_utc_time, printd
@@ -359,37 +356,6 @@ class BlockModel(Base):
)
class ToolModel(Base):
__tablename__ = "tools"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
name = Column(String, nullable=False)
user_id = Column(String)
description = Column(String)
source_type = Column(String)
source_code = Column(String)
json_schema = Column(JSON)
module = Column(String)
tags = Column(JSON)
def __repr__(self) -> str:
return f"<Tool(id='{self.id}', name='{self.name}')>"
def to_record(self) -> Tool:
return Tool(
id=self.id,
name=self.name,
user_id=self.user_id,
description=self.description,
source_type=self.source_type,
source_code=self.source_code,
json_schema=self.json_schema,
module=self.module,
tags=self.tags,
)
class JobModel(Base):
__tablename__ = "jobs"
__table_args__ = {"extend_existing": True}
@@ -516,14 +482,6 @@ class MetadataStore:
session.add(BlockModel(**vars(block)))
session.commit()
@enforce_types
def create_tool(self, tool: Tool):
with self.session_maker() as session:
if self.get_tool(tool_id=tool.id, tool_name=tool.name, user_id=tool.user_id) is not None:
raise ValueError(f"Tool with name {tool.name} already exists")
session.add(ToolModel(**vars(tool)))
session.commit()
@enforce_types
def update_agent(self, agent: AgentState):
with self.session_maker() as session:
@@ -556,18 +514,6 @@ class MetadataStore:
session.add(BlockModel(**vars(block)))
session.commit()
@enforce_types
def update_tool(self, tool_id: str, tool: Tool):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.id == tool_id).update(vars(tool))
session.commit()
@enforce_types
def delete_tool(self, tool_id: str):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.id == tool_id).delete()
session.commit()
@enforce_types
def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]):
with self.session_maker() as session:
@@ -612,23 +558,6 @@ class MetadataStore:
session.commit()
@enforce_types
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
with self.session_maker() as session:
# Query for public tools or user-specific tools
query = session.query(ToolModel).filter(or_(ToolModel.user_id == None, ToolModel.user_id == user_id))
# Apply cursor if provided (assuming cursor is an ID)
if cursor:
query = query.filter(ToolModel.id > cursor)
# Order by ID and apply limit
results = query.order_by(asc(ToolModel.id)).limit(limit).all()
# Convert to records
res = [r.to_record() for r in results]
return res
@enforce_types
def list_agents(self, user_id: str) -> List[AgentState]:
with self.session_maker() as session:
@@ -672,32 +601,6 @@ class MetadataStore:
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def get_tool(
self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None
) -> Optional[ToolModel]:
with self.session_maker() as session:
if tool_id:
results = session.query(ToolModel).filter(ToolModel.id == tool_id).all()
else:
assert tool_name is not None
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
if len(results) == 0:
return None
# assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def get_tool_with_name_and_user_id(self, tool_name: Optional[str] = None, user_id: Optional[str] = None) -> Optional[ToolModel]:
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def get_block(self, block_id: str) -> Optional[Block]:
with self.session_maker() as session:

View File

@@ -10,7 +10,7 @@ from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
def send_thinking_message(self: Agent, message: str) -> Optional[str]:
def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
"""
Sends a thinking message so that the model can reason out loud before responding.
@@ -24,7 +24,7 @@ def send_thinking_message(self: Agent, message: str) -> Optional[str]:
return None
def send_final_message(self: Agent, message: str) -> Optional[str]:
def send_final_message(self: "Agent", message: str) -> Optional[str]:
"""
Sends a final message to the human user after thinking for a while.

View File

@@ -0,0 +1,15 @@
"""__all__ acts as manual import management to avoid collisions and circular imports."""
# from letta.orm.agent import Agent
# from letta.orm.users_agents import UsersAgents
# from letta.orm.blocks_agents import BlocksAgents
# from letta.orm.token import Token
# from letta.orm.source import Source
# from letta.orm.document import Document
# from letta.orm.passage import Passage
# from letta.orm.memory_templates import MemoryTemplate, HumanMemoryTemplate, PersonaMemoryTemplate
# from letta.orm.sources_agents import SourcesAgents
# from letta.orm.tools_agents import ToolsAgents
# from letta.orm.job import Job
# from letta.orm.block import Block
# from letta.orm.message import Message

View File

@@ -55,7 +55,6 @@ class OrganizationMixin(Base):
__abstract__ = True
# Changed _organization_id to store string (still a valid UUID4 string)
_organization_id: Mapped[str] = mapped_column(String, ForeignKey("organization._id"))
@property
@@ -65,3 +64,19 @@ class OrganizationMixin(Base):
@organization_id.setter
def organization_id(self, value: str) -> None:
_relation_setter(self, "organization", value)
class UserMixin(Base):
"""Mixin for models that belong to a user."""
__abstract__ = True
_user_id: Mapped[str] = mapped_column(String, ForeignKey("user._id"))
@property
def user_id(self) -> str:
return _relation_getter(self, "user")
@user_id.setter
def user_id(self, value: str) -> None:
_relation_setter(self, "user", value)

View File

@@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization
if TYPE_CHECKING:
from letta.orm.tool import Tool
from letta.orm.user import User
@@ -19,6 +20,7 @@ class Organization(SqlalchemyBase):
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
# TODO: Map these relationships later when we actually make these models
# below is just a suggestion

View File

@@ -184,21 +184,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
logger.warning("to_record is deprecated, use to_pydantic instead.")
return self.to_pydantic()
# TODO: Look into this later and maybe add back?
# def _infer_organization(self, db_session: "Session") -> None:
# """🪄 MAGIC ALERT! 🪄
# Because so much of the original API is centered around user scopes,
# this allows us to continue with that scope and then infer the org from the creating user.
#
# IF a created_by_id is set, we will use that to infer the organization and magic set it at create time!
# If not do nothing to the object. Mutates in place.
# """
# if self.created_by_id and hasattr(self, "_organization_id"):
# try:
# from letta.orm.user import User # to avoid circular import
#
# created_by = User.read(db_session, self.created_by_id)
# except NoResultFound:
# logger.warning(f"User {self.created_by_id} not found, unable to infer organization.")
# return
# self._organization_id = created_by._organization_id
def _infer_organization(self, db_session: "Session") -> None:
"""🪄 MAGIC ALERT! 🪄
Because so much of the original API is centered around user scopes,
this allows us to continue with that scope and then infer the org from the creating user.
IF a created_by_id is set, we will use that to infer the organization and magic set it at create time!
If not do nothing to the object. Mutates in place.
"""
if self.created_by_id and hasattr(self, "_organization_id"):
try:
from letta.orm.user import User # to avoid circular import
created_by = User.read(db_session, self.created_by_id)
except NoResultFound:
logger.warning(f"User {self.created_by_id} not found, unable to infer organization.")
return
self._organization_id = created_by._organization_id

54
letta/orm/tool.py Normal file
View File

@@ -0,0 +1,54 @@
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
# TODO everything in functions should live in this model
from letta.orm.enums import ToolSourceType
from letta.orm.mixins import OrganizationMixin, UserMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.tool import Tool as PydanticTool
if TYPE_CHECKING:
pass
from letta.orm.organization import Organization
from letta.orm.user import User
class Tool(SqlalchemyBase, OrganizationMixin, UserMixin):
"""Represents an available tool that the LLM can invoke.
NOTE: polymorphic inheritance makes more sense here as a TODO. We want a superset of tools
that are always available, and a subset scoped to the organization. Alternatively, we could use the apply_access_predicate to build
more granular permissions.
"""
__tablename__ = "tool"
__pydantic_model__ = PydanticTool
# Add unique constraint on (name, _organization_id)
# An organization should not have multiple tools with the same name
__table_args__ = (
UniqueConstraint("name", "_organization_id", name="uix_name_organization"),
UniqueConstraint("name", "_user_id", name="uix_name_user"),
)
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
json_schema: Mapped[dict] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
module: Mapped[Optional[str]] = mapped_column(
String, nullable=True, doc="the module path from which this tool was derived in the codebase."
)
# TODO: add terminal here eventually
# This was an intentional decision by Sarah
# relationships
# TODO: Possibly add in user in the future
# This will require some more thought and justification to add this in.
user: Mapped["User"] = relationship("User", back_populates="tools", lazy="selectin")
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")

View File

@@ -1,10 +1,15 @@
from typing import TYPE_CHECKING, List
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.user import User as PydanticUser
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.tool import Tool
class User(SqlalchemyBase, OrganizationMixin):
"""User ORM class"""
@@ -16,6 +21,7 @@ class User(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="user", cascade="all, delete-orphan")
# TODO: Add this back later potentially
# agents: Mapped[List["Agent"]] = relationship(

View File

@@ -10,19 +10,13 @@ from letta.functions.helpers import (
from letta.functions.schema_generator import generate_schema_from_args_schema
from letta.schemas.letta_base import LettaBase
from letta.schemas.openai.chat_completions import ToolCall
from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager
class BaseTool(LettaBase):
__id_prefix__ = "tool"
# optional fields
description: Optional[str] = Field(None, description="The description of the tool.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
module: Optional[str] = Field(None, description="The module of the function.")
# optional: user_id (user-specific tools)
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the function.")
class Tool(BaseTool):
"""
@@ -37,8 +31,12 @@ class Tool(BaseTool):
"""
id: str = BaseTool.generate_id_field()
id: str = Field(..., description="The id of the tool.")
description: Optional[str] = Field(None, description="The description of the tool.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
module: Optional[str] = Field(None, description="The module of the function.")
user_id: str = Field(..., description="The unique identifier of the user associated with the tool.")
organization_id: str = Field(..., description="The unique identifier of the organization associated with the tool.")
name: str = Field(..., description="The name of the function.")
tags: List[str] = Field(..., description="Metadata tags.")
@@ -58,14 +56,31 @@ class Tool(BaseTool):
)
)
class ToolCreate(LettaBase):
user_id: str = Field(UserManager.DEFAULT_USER_ID, description="The user that this tool belongs to. Defaults to the default user ID.")
organization_id: str = Field(
OrganizationManager.DEFAULT_ORG_ID,
description="The organization that this tool belongs to. Defaults to the default organization ID.",
)
name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).")
description: Optional[str] = Field(None, description="The description of the tool.")
tags: List[str] = Field([], description="Metadata tags.")
module: Optional[str] = Field(None, description="The source code of the function.")
source_code: str = Field(..., description="The source code of the function.")
source_type: str = Field(..., description="The source type of the function.")
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).")
@classmethod
def get_composio_tool(
cls,
action: "ActionType",
) -> "Tool":
def from_composio(
cls, action: "ActionType", user_id: str = UserManager.DEFAULT_USER_ID, organization_id: str = OrganizationManager.DEFAULT_ORG_ID
) -> "ToolCreate":
"""
Class method to create an instance of Letta-compatible Composio Tool.
Check https://docs.composio.dev/introduction/intro/overview to look at options for get_composio_tool
Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio
This function will error if we find more than one tool, or 0 tools.
@@ -90,14 +105,9 @@ class Tool(BaseTool):
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action)
json_schema = generate_schema_from_args_schema(composio_tool.args_schema, name=wrapper_func_name, description=description)
# append heartbeat (necessary for triggering another reasoning step after this tool call)
json_schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")
return cls(
user_id=user_id,
organization_id=organization_id,
name=wrapper_func_name,
description=description,
source_type=source_type,
@@ -107,7 +117,13 @@ class Tool(BaseTool):
)
@classmethod
def from_langchain(cls, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool":
def from_langchain(
cls,
langchain_tool: "LangChainBaseTool",
additional_imports_module_attr_map: dict[str, str] = None,
user_id: str = UserManager.DEFAULT_USER_ID,
organization_id: str = OrganizationManager.DEFAULT_ORG_ID,
) -> "ToolCreate":
"""
Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools).
@@ -125,14 +141,9 @@ class Tool(BaseTool):
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool, additional_imports_module_attr_map)
json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description)
# append heartbeat (necessary for triggering another reasoning step after this tool call)
json_schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")
return cls(
user_id=user_id,
organization_id=organization_id,
name=wrapper_func_name,
description=description,
source_type=source_type,
@@ -142,7 +153,13 @@ class Tool(BaseTool):
)
@classmethod
def from_crewai(cls, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool":
def from_crewai(
cls,
crewai_tool: "CrewAIBaseTool",
additional_imports_module_attr_map: dict[str, str] = None,
user_id: str = UserManager.DEFAULT_USER_ID,
organization_id: str = OrganizationManager.DEFAULT_ORG_ID,
) -> "ToolCreate":
"""
Class method to create an instance of Tool from a crewAI BaseTool object.
@@ -158,14 +175,9 @@ class Tool(BaseTool):
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool, additional_imports_module_attr_map)
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
# append heartbeat (necessary for triggering another reasoning step after this tool call)
json_schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")
return cls(
user_id=user_id,
organization_id=organization_id,
name=wrapper_func_name,
description=description,
source_type=source_type,
@@ -175,54 +187,43 @@ class Tool(BaseTool):
)
@classmethod
def load_default_langchain_tools(cls) -> List["Tool"]:
def load_default_langchain_tools(cls) -> List["ToolCreate"]:
# For now, we only support wikipedia tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
wikipedia_tool = Tool.from_langchain(
wikipedia_tool = ToolCreate.from_langchain(
WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()), {"langchain_community.utilities": "WikipediaAPIWrapper"}
)
return [wikipedia_tool]
@classmethod
def load_default_crewai_tools(cls) -> List["Tool"]:
def load_default_crewai_tools(cls) -> List["ToolCreate"]:
# For now, we only support scrape website tool
from crewai_tools import ScrapeWebsiteTool
web_scrape_tool = Tool.from_crewai(ScrapeWebsiteTool())
web_scrape_tool = ToolCreate.from_crewai(ScrapeWebsiteTool())
return [web_scrape_tool]
@classmethod
def load_default_composio_tools(cls) -> List["Tool"]:
def load_default_composio_tools(cls) -> List["ToolCreate"]:
from composio_langchain import Action
calculator = Tool.get_composio_tool(action=Action.MATHEMATICAL_CALCULATOR)
serp_news = Tool.get_composio_tool(action=Action.SERPAPI_NEWS_SEARCH)
serp_google_search = Tool.get_composio_tool(action=Action.SERPAPI_SEARCH)
serp_google_maps = Tool.get_composio_tool(action=Action.SERPAPI_GOOGLE_MAPS_SEARCH)
calculator = ToolCreate.from_composio(action=Action.MATHEMATICAL_CALCULATOR)
serp_news = ToolCreate.from_composio(action=Action.SERPAPI_NEWS_SEARCH)
serp_google_search = ToolCreate.from_composio(action=Action.SERPAPI_SEARCH)
serp_google_maps = ToolCreate.from_composio(action=Action.SERPAPI_GOOGLE_MAPS_SEARCH)
return [calculator, serp_news, serp_google_search, serp_google_maps]
class ToolCreate(BaseTool):
id: Optional[str] = Field(None, description="The unique identifier of the tool. If this is not provided, it will be autogenerated.")
name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).")
description: Optional[str] = Field(None, description="The description of the tool.")
tags: List[str] = Field([], description="Metadata tags.")
source_code: str = Field(..., description="The source code of the function.")
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).")
class ToolUpdate(ToolCreate):
id: str = Field(..., description="The unique identifier of the tool.")
class ToolUpdate(LettaBase):
description: Optional[str] = Field(None, description="The description of the tool.")
name: Optional[str] = Field(None, description="The name of the function.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
module: Optional[str] = Field(None, description="The source code of the function.")
source_code: Optional[str] = Field(None, description="The source code of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
source_type: Optional[str] = Field(None, description="The type of the source code.")

View File

@@ -3,8 +3,8 @@ from typing import Optional
from pydantic import Field
from letta.constants import DEFAULT_ORG_ID
from letta.schemas.letta_base import LettaBase
from letta.services.organization_manager import OrganizationManager
class UserBase(LettaBase):
@@ -22,7 +22,7 @@ class User(UserBase):
"""
id: str = Field(..., description="The id of the user.")
organization_id: Optional[str] = Field(DEFAULT_ORG_ID, description="The organization id of the user")
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
name: str = Field(..., description="The name of the user.")
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="The update date of the user.")

View File

@@ -31,7 +31,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
Create a new user in the database
"""
try:
user = server.create_user(request)
user = server.user_manager.create_user(request)
except HTTPException:
raise
except Exception as e:

View File

@@ -2,6 +2,7 @@ from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.orm.errors import NoResultFound
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
@@ -13,13 +14,12 @@ router = APIRouter(prefix="/tools", tags=["tools"])
def delete_tool(
tool_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a tool by name
"""
# actor = server.get_user_or_default(user_id=user_id)
server.delete_tool(tool_id=tool_id)
server.tool_manager.delete_tool(tool_id=tool_id)
@router.get("/{tool_id}", response_model=Tool, operation_id="get_tool")
@@ -30,9 +30,7 @@ def get_tool(
"""
Get a tool by ID
"""
# actor = server.get_current_user()
tool = server.get_tool(tool_id=tool_id)
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
@@ -50,26 +48,26 @@ def get_tool_id(
"""
actor = server.get_user_or_default(user_id=user_id)
tool_id = server.get_tool_id(tool_name, user_id=actor.id)
if tool_id is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
return tool_id
try:
tool = server.tool_manager.get_tool_by_name_and_org_id(tool_name=tool_name, organization_id=actor.organization_id)
return tool.id
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} and organization id {actor.organization_id} not found.")
@router.get("/", response_model=List[Tool], operation_id="list_tools")
def list_all_tools(
def list_tools(
cursor: Optional[str] = None,
limit: Optional[int] = 50,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a list of all tools available to agents created by a user
Get a list of all tools available to agents belonging to the org of the user
"""
try:
actor = server.get_user_or_default(user_id=user_id)
return server.list_tools(cursor=cursor, limit=limit, user_id=actor.id)
return server.tool_manager.list_tools_for_org(organization_id=actor.organization_id, cursor=cursor, limit=limit)
except Exception as e:
# Log or print the full exception here for debugging
print(f"Error occurred: {e}")
@@ -78,21 +76,21 @@ def list_all_tools(
@router.post("/", response_model=Tool, operation_id="create_tool")
def create_tool(
tool: ToolCreate = Body(...),
update: bool = False,
request: ToolCreate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new tool
"""
# Derive user and org id from actor
actor = server.get_user_or_default(user_id=user_id)
request.organization_id = actor.organization_id
request.user_id = actor.id
return server.create_tool(
request=tool,
# update=update,
update=True,
user_id=actor.id,
# Send request to create the tool
return server.tool_manager.create_or_update_tool(
tool_create=request,
)
@@ -106,6 +104,4 @@ def update_tool(
"""
Update an existing tool
"""
assert tool_id == request.id, "Tool ID in path must match tool ID in request body"
# actor = server.get_user_or_default(user_id=user_id)
return server.update_tool(request, user_id)
return server.tool_manager.update_tool_by_id(tool_id, request)

View File

@@ -1,6 +1,4 @@
# inspecting tools
import importlib
import inspect
import os
import traceback
import warnings
@@ -16,7 +14,6 @@ import letta.system as system
from letta.agent import Agent, save_agent
from letta.agent_store.db import attach_base
from letta.agent_store.storage import StorageConnector, TableType
from letta.client.utils import derive_function_name_regex
from letta.credentials import LettaCredentials
from letta.data_sources.connectors import DataConnector, load_data
@@ -30,11 +27,7 @@ from letta.data_sources.connectors import DataConnector, load_data
# Token,
# User,
# )
from letta.functions.functions import (
generate_schema,
load_function_set,
parse_source_code,
)
from letta.functions.functions import generate_schema, parse_source_code
from letta.functions.schema_generator import generate_schema
# TODO use custom interface
@@ -82,10 +75,11 @@ from letta.schemas.memory import (
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.passage import Passage
from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.schemas.tool import Tool, ToolCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User, UserCreate
from letta.schemas.user import User
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import create_random_username, json_dumps, json_loads
@@ -214,6 +208,7 @@ class SyncServer(Server):
chaining: bool = True,
max_chaining_steps: Optional[bool] = None,
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
init_with_default_org_and_user: bool = True,
# default_interface: AgentInterface = CLIInterface(),
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
# auth_mode: str = "none", # "none, "jwt", "external"
@@ -249,13 +244,19 @@ class SyncServer(Server):
# Managers that interface with data models
self.organization_manager = OrganizationManager()
self.user_manager = UserManager()
self.tool_manager = ToolManager()
# TODO: this should be removed
# add global default tools (for admin)
self.add_default_tools(module_name="base")
# Make default user and org
if init_with_default_org_and_user:
self.default_org = self.organization_manager.create_default_organization()
self.default_user = self.user_manager.create_default_user()
self.add_default_blocks(self.default_user.id)
self.tool_manager.add_default_tools(module_name="base", user_id=self.default_user.id, org_id=self.default_org.id)
if settings.load_default_external_tools:
self.add_default_external_tools()
# If there is a default org/user
# This logic may have to change in the future
if settings.load_default_external_tools:
self.add_default_external_tools(user_id=self.default_user.id, org_id=self.default_org.id)
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()]
@@ -364,7 +365,7 @@ class SyncServer(Server):
logger.debug(f"Creating an agent object")
tool_objs = []
for name in agent_state.tools:
tool_obj = self.ms.get_tool(tool_name=name, user_id=user_id)
tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=name, user_id=user_id)
if not tool_obj:
logger.exception(f"Tool {name} does not exist for user {user_id}")
raise ValueError(f"Tool {name} does not exist for user {user_id}")
@@ -755,22 +756,6 @@ class SyncServer(Server):
command = command[1:] # strip the prefix
return self._command(user_id=user_id, agent_id=agent_id, command=command)
def create_user(self, request: UserCreate) -> User:
"""Create a new user using a config"""
if not request.name:
# auto-generate a name
request.name = create_random_username()
user = self.user_manager.create_user(request)
logger.debug(f"Created new user from config: {user}")
# add default for the user
# TODO: move to org
assert user.id is not None, f"User id is None: {user}"
self.add_default_blocks(user.id)
self.add_default_tools(module_name="base", user_id=user.id)
return user
def create_agent(
self,
request: CreateAgent,
@@ -816,8 +801,7 @@ class SyncServer(Server):
tool_objs = []
if request.tools:
for tool_name in request.tools:
tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
assert tool_obj, f"Tool {tool_name} does not exist"
tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=user_id)
tool_objs.append(tool_obj)
assert request.memory is not None
@@ -832,16 +816,15 @@ class SyncServer(Server):
json_schema = generate_schema(func, terminal=False, name=func_name)
source_type = "python"
tags = ["memory", "memgpt-base"]
tool = self.create_tool(
request=ToolCreate(
tool = self.tool_manager.create_or_update_tool(
ToolCreate(
source_code=source_code,
source_type=source_type,
tags=tags,
json_schema=json_schema,
user_id=user_id,
),
update=True,
user_id=user_id,
organization_id=user.organization_id,
)
)
tool_objs.append(tool)
if not request.tools:
@@ -939,7 +922,7 @@ class SyncServer(Server):
# (1) get tools + make sure they exist
tool_objs = []
for tool_name in request.tools:
tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=user_id)
assert tool_obj, f"Tool {tool_name} does not exist"
tool_objs.append(tool_obj)
@@ -995,12 +978,12 @@ class SyncServer(Server):
# Get all the tool objects from the request
tool_objs = []
tool_obj = self.ms.get_tool(tool_id=tool_id, user_id=user_id)
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id)
assert tool_obj, f"Tool with id={tool_id} does not exist"
tool_objs.append(tool_obj)
for tool in letta_agent.tools:
tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id)
assert tool_obj, f"Tool with id={tool.id} does not exist"
# If it's not the already added tool
@@ -1035,7 +1018,7 @@ class SyncServer(Server):
# Get all the tool_objs
tool_objs = []
for tool in letta_agent.tools:
tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id)
assert tool_obj, f"Tool with id={tool.id} does not exist"
# If it's not the tool we want to remove
@@ -1076,86 +1059,6 @@ class SyncServer(Server):
agents_states = self.ms.list_agents(user_id=user_id)
return agents_states
# TODO make return type pydantic
def list_agents_legacy(
self,
user_id: str,
) -> dict:
"""List all available agents to a user"""
if user_id is None:
agents_states = self.ms.list_all_agents()
else:
if self.user_manager.get_user_by_id(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
agents_states = self.ms.list_agents(user_id=user_id)
agents_states_dicts = [self._agent_state_to_config(state) for state in agents_states]
# TODO add a get_message_obj_from_message_id(...) function
# this would allow grabbing Message.created_by without having to load the agent object
# all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
self.ms.list_tools()
for agent_state, return_dict in zip(agents_states, agents_states_dicts):
# Get the agent object (loaded in memory)
letta_agent = self._get_or_load_agent(user_id=agent_state.user_id, agent_id=agent_state.id)
# TODO remove this eventually when return type get pydanticfied
# this is to add persona_name and human_name so that the columns in UI can populate
# TODO hack for frontend, remove
# (top level .persona is persona_name, and nested memory.persona is the state)
# TODO: eventually modify this to be contained in the metadata
return_dict["persona"] = agent_state._metadata.get("persona", None)
return_dict["human"] = agent_state._metadata.get("human", None)
# Add information about tools
# TODO letta_agent should really have a field of List[ToolModel]
# then we could just pull that field and return it here
# return_dict["tools"] = [tool for tool in all_available_tools if tool.json_schema in letta_agent.functions]
# get tool info from agent state
tools = []
for tool_name in agent_state.tools:
tool = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
tools.append(tool)
return_dict["tools"] = tools
# Add information about memory (raw core, size of recall, size of archival)
core_memory = letta_agent.memory
recall_memory = letta_agent.persistence_manager.recall_memory
archival_memory = letta_agent.persistence_manager.archival_memory
memory_obj = {
"core_memory": core_memory.to_flat_dict(),
"recall_memory": len(recall_memory) if recall_memory is not None else None,
"archival_memory": len(archival_memory) if archival_memory is not None else None,
}
return_dict["memory"] = memory_obj
# Add information about last run
# NOTE: 'last_run' is just the timestamp on the latest message in the buffer
# Retrieve the Message object via the recall storage or by directly access _messages
last_msg_obj = letta_agent._messages[-1]
return_dict["last_run"] = last_msg_obj.created_at
# Add information about attached sources
sources_ids = self.ms.list_attached_sources(agent_id=agent_state.id)
sources = [self.ms.get_source(source_id=s_id) for s_id in sources_ids]
return_dict["sources"] = [vars(s) for s in sources]
# Sort agents by "last_run" in descending order, most recent first
agents_states_dicts.sort(key=lambda x: x["last_run"], reverse=True)
logger.debug(f"Retrieved {len(agents_states)} agents for user {user_id}")
return {
"num_agents": len(agents_states),
"agents": agents_states_dicts,
}
# blocks
def get_blocks(
self,
user_id: Optional[str] = None,
@@ -1830,195 +1733,17 @@ class SyncServer(Server):
return sources_with_metadata
def get_tool(self, tool_id: str) -> Optional[Tool]:
"""Get tool by ID."""
return self.ms.get_tool(tool_id=tool_id)
def tool_with_name_and_user_id_exists(self, tool: Tool, user_id: Optional[str] = None) -> bool:
"""Check if tool exists"""
tool = self.ms.get_tool_with_name_and_user_id(tool_name=tool.name, user_id=user_id)
if tool is None:
return False
else:
return True
def get_tool_id(self, name: str, user_id: str) -> Optional[str]:
"""Get tool ID from name and user_id."""
tool = self.ms.get_tool(tool_name=name, user_id=user_id)
if not tool or tool.id is None:
return None
return tool.id
def update_tool(self, request: ToolUpdate, user_id: Optional[str] = None) -> Tool:
"""Update an existing tool"""
if request.name:
existing_tool = self.ms.get_tool_with_name_and_user_id(tool_name=request.name, user_id=user_id)
if existing_tool is None:
raise ValueError(f"Tool with name={request.name}, user_id={user_id} does not exist")
else:
existing_tool = self.ms.get_tool(tool_id=request.id)
if existing_tool is None:
raise ValueError(f"Tool with id={request.id} does not exist")
# Preserve the original tool id
# As we can override the tool id as well
# This is probably bad design if this is exposed to users...
original_id = existing_tool.id
# override updated fields
if request.id:
existing_tool.id = request.id
if request.description:
existing_tool.description = request.description
if request.source_code:
existing_tool.source_code = request.source_code
if request.source_type:
existing_tool.source_type = request.source_type
if request.tags:
existing_tool.tags = request.tags
if request.json_schema:
existing_tool.json_schema = request.json_schema
# If name is explicitly provided here, overide the tool name
if request.name:
existing_tool.name = request.name
# Otherwise, if there's no name, and there's source code, we try to derive the name
elif request.source_code:
existing_tool.name = derive_function_name_regex(request.source_code)
self.ms.update_tool(original_id, existing_tool)
return self.ms.get_tool(tool_id=request.id)
def create_tool(self, request: ToolCreate, user_id: Optional[str] = None, update: bool = True) -> Tool: # TODO: add other fields
"""Create a new tool"""
# NOTE: deprecated code that existed when we were trying to pretend that `self` was the memory object
# if request.tags and "memory" in request.tags:
# # special modifications to memory functions
# # self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory)
# request.source_code = request.source_code.replace("self.memory", "self.memory.memory")
if not request.json_schema:
# auto-generate openai schema
try:
env = {}
env.update(globals())
exec(request.source_code, env)
# get available functions
functions = [f for f in env if callable(env[f])]
except Exception as e:
logger.error(f"Failed to execute source code: {e}")
# TODO: not sure if this always works
func = env[functions[-1]]
json_schema = generate_schema(func, terminal=request.terminal)
else:
# provided by client
json_schema = request.json_schema
if not request.name:
# use name from JSON schema
request.name = json_schema["name"]
assert request.name, f"Tool name must be provided in json_schema {json_schema}. This should never happen."
# check if already exists:
existing_tool = self.ms.get_tool(tool_id=request.id, tool_name=request.name, user_id=user_id)
if existing_tool:
if update:
# id is an optional field, so we will fill it with the existing tool id
if not request.id:
request.id = existing_tool.id
updated_tool = self.update_tool(ToolUpdate(**vars(request)), user_id)
assert updated_tool is not None, f"Failed to update tool {request.name}"
return updated_tool
else:
raise ValueError(f"Tool {request.name} already exists and update=False")
# check for description
description = None
if request.description:
description = request.description
tool = Tool(
name=request.name,
source_code=request.source_code,
source_type=request.source_type,
tags=request.tags,
json_schema=json_schema,
user_id=user_id,
description=description,
)
if request.id:
tool.id = request.id
self.ms.create_tool(tool)
created_tool = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
return created_tool
def delete_tool(self, tool_id: str):
"""Delete a tool"""
self.ms.delete_tool(tool_id)
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[Tool]:
"""List tools available to user_id"""
tools = self.ms.list_tools(cursor=cursor, limit=limit, user_id=user_id)
return tools
def add_default_tools(self, module_name="base", user_id: Optional[str] = None):
"""Add default tools in {module_name}.py"""
full_module_name = f"letta.functions.function_sets.{module_name}"
try:
module = importlib.import_module(full_module_name)
except Exception as e:
# Handle other general exceptions
raise e
functions_to_schema = []
try:
# Load the function set
functions_to_schema = load_function_set(module)
except ValueError as e:
err = f"Error loading function set '{module_name}': {e}"
warnings.warn(err)
# create tool in db
for name, schema in functions_to_schema.items():
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
if module_name == "base":
tags.append("letta-base")
# create to tool
self.create_tool(
ToolCreate(
name=name,
tags=tags,
source_type="python",
module=schema["module"],
source_code=source_code,
json_schema=schema["json_schema"],
user_id=user_id,
),
update=True,
)
def add_default_external_tools(self, user_id: Optional[str] = None) -> bool:
def add_default_external_tools(self, user_id: str, org_id: str) -> bool:
"""Add default langchain tools. Return true if successful, false otherwise."""
success = True
tool_creates = ToolCreate.load_default_langchain_tools() + ToolCreate.load_default_crewai_tools()
if tool_settings.composio_api_key:
tools = Tool.load_default_langchain_tools() + Tool.load_default_crewai_tools() + Tool.load_default_composio_tools()
else:
tools = Tool.load_default_langchain_tools() + Tool.load_default_crewai_tools()
for tool in tools:
tool_creates += ToolCreate.load_default_composio_tools()
for tool_create in tool_creates:
try:
self.ms.create_tool(tool)
self.tool_manager.create_or_update_tool(tool_create)
except Exception as e:
warnings.warn(f"An error occurred while creating tool {tool}: {e}")
warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
warnings.warn(traceback.format_exc())
success = False
@@ -2108,25 +1833,15 @@ class SyncServer(Server):
letta_agent = self._get_or_load_agent(agent_id=agent_id)
return letta_agent.retry_message()
# TODO: Move a lot of this default logic to the ORM
def get_default_user(self) -> User:
self.organization_manager.create_default_organization()
user = self.user_manager.create_default_user()
self.add_default_blocks(user.id)
self.add_default_tools(module_name="base", user_id=user.id)
return user
def get_user_or_default(self, user_id: Optional[str]) -> User:
"""Get the user object for user_id if it exists, otherwise return the default user object"""
if user_id is None:
return self.get_default_user()
else:
try:
return self.user_manager.get_user_by_id(user_id=user_id)
except ValueError:
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")
user_id = self.user_manager.DEFAULT_USER_ID
try:
return self.user_manager.get_user_by_id(user_id=user_id)
except ValueError:
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")
def list_llm_models(self) -> List[LLMConfig]:
"""List available models"""

View File

@@ -1,8 +1,7 @@
from typing import List, Optional
from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization
from letta.orm.organization import Organization as OrganizationModel
from letta.schemas.organization import Organization as PydanticOrganization
from letta.utils import create_random_username, enforce_types
@@ -10,6 +9,9 @@ from letta.utils import create_random_username, enforce_types
class OrganizationManager:
"""Manager class to handle business logic related to Organizations."""
DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000"
DEFAULT_ORG_NAME = "default_org"
def __init__(self):
# This is probably horrible but we reuse this technique from metadata.py
# TODO: Please refactor this out
@@ -19,12 +21,17 @@ class OrganizationManager:
self.session_maker = db_context
@enforce_types
def get_default_organization(self) -> PydanticOrganization:
"""Fetch the default organization."""
return self.get_organization_by_id(self.DEFAULT_ORG_ID)
@enforce_types
def get_organization_by_id(self, org_id: str) -> PydanticOrganization:
"""Fetch an organization by ID."""
with self.session_maker() as session:
try:
organization = Organization.read(db_session=session, identifier=org_id)
organization = OrganizationModel.read(db_session=session, identifier=org_id)
return organization.to_pydantic()
except NoResultFound:
raise ValueError(f"Organization with id {org_id} not found.")
@@ -33,7 +40,7 @@ class OrganizationManager:
def create_organization(self, name: Optional[str] = None) -> PydanticOrganization:
"""Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
with self.session_maker() as session:
org = Organization(name=name if name else create_random_username())
org = OrganizationModel(name=name if name else create_random_username())
org.create(session)
return org.to_pydantic()
@@ -43,10 +50,10 @@ class OrganizationManager:
with self.session_maker() as session:
# Try to get it first
try:
org = Organization.read(db_session=session, identifier=DEFAULT_ORG_ID)
org = OrganizationModel.read(db_session=session, identifier=self.DEFAULT_ORG_ID)
# If it doesn't exist, make it
except NoResultFound:
org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID)
org = OrganizationModel(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
org.create(session)
return org.to_pydantic()
@@ -55,22 +62,22 @@ class OrganizationManager:
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
"""Update an organization."""
with self.session_maker() as session:
organization = Organization.read(db_session=session, identifier=org_id)
org = OrganizationModel.read(db_session=session, identifier=org_id)
if name:
organization.name = name
organization.update(session)
return organization.to_pydantic()
org.name = name
org.update(session)
return org.to_pydantic()
@enforce_types
def delete_organization_by_id(self, org_id: str):
"""Delete an organization by marking it as deleted."""
with self.session_maker() as session:
organization = Organization.read(db_session=session, identifier=org_id)
organization = OrganizationModel.read(db_session=session, identifier=org_id)
organization.delete(session)
@enforce_types
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
"""List organizations with pagination based on cursor (org_id) and limit."""
with self.session_maker() as session:
results = Organization.list(db_session=session, cursor=cursor, limit=limit)
results = OrganizationModel.list(db_session=session, cursor=cursor, limit=limit)
return [org.to_pydantic() for org in results]

View File

@@ -0,0 +1,193 @@
import importlib
import inspect
import warnings
from typing import List, Optional
from letta.functions.functions import derive_openai_json_schema, load_function_set
# TODO: Remove this once we translate all of these to the ORM
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.orm.tool import Tool as ToolModel
from letta.orm.user import User as UserModel
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.utils import enforce_types
class ToolManager:
"""Manager class to handle business logic related to Tools."""
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.server import db_context
self.session_maker = db_context
@enforce_types
def create_or_update_tool(self, tool_create: ToolCreate) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Derive json_schema
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(tool_create)
derived_name = tool_create.name or derived_json_schema["name"]
try:
# NOTE: We use the organization id here
# This is important, because even if it's a different user, adding the same tool to the org should not happen
tool = self.get_tool_by_name_and_org_id(tool_name=derived_name, organization_id=tool_create.organization_id)
# Put to dict and remove fields that should not be reset
update_data = tool_create.model_dump(exclude={"user_id", "organization_id", "module", "terminal"}, exclude_unset=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value}
# If there's anything to update
if update_data:
self.update_tool_by_id(tool.id, ToolUpdate(**update_data))
else:
warnings.warn(
f"`create_or_update_tool` was called with user_id={tool_create.user_id}, organization_id={tool_create.organization_id}, name={tool_create.name}, but found existing tool with nothing to update."
)
except NoResultFound:
tool_create.json_schema = derived_json_schema
tool_create.name = derived_name
tool = self.create_tool(tool_create)
return tool
@enforce_types
def create_tool(self, tool_create: ToolCreate) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Create the tool
with self.session_maker() as session:
# Include all fields except 'terminal' (which is not part of ToolModel) at the moment
create_data = tool_create.model_dump(exclude={"terminal"})
tool = ToolModel(**create_data) # Unpack everything directly into ToolModel
tool.create(session)
return tool.to_pydantic()
@enforce_types
def get_tool_by_id(self, tool_id: str) -> PydanticTool:
"""Fetch a tool by its ID."""
with self.session_maker() as session:
try:
# Retrieve tool by id using the Tool model's read method
tool = ToolModel.read(db_session=session, identifier=tool_id)
# Convert the SQLAlchemy Tool object to PydanticTool
return tool.to_pydantic()
except NoResultFound:
raise ValueError(f"Tool with id {tool_id} not found.")
@enforce_types
def get_tool_by_name_and_user_id(self, tool_name: str, user_id: str) -> PydanticTool:
"""Retrieve a tool by its name and organization_id."""
with self.session_maker() as session:
# Use the list method to apply filters
results = ToolModel.list(db_session=session, name=tool_name, _user_id=UserModel.get_uid_from_identifier(user_id))
# Ensure only one result is returned (since there is a unique constraint)
if not results:
raise NoResultFound(f"Tool with name {tool_name} and user_id {user_id} not found.")
if len(results) > 1:
raise RuntimeError(
f"Multiple tools with name {tool_name} and user_id {user_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error."
)
# Return the single result
return results[0]
@enforce_types
def get_tool_by_name_and_org_id(self, tool_name: str, organization_id: str) -> PydanticTool:
"""Retrieve a tool by its name and organization_id."""
with self.session_maker() as session:
# Use the list method to apply filters
results = ToolModel.list(
db_session=session, name=tool_name, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id)
)
# Ensure only one result is returned (since there is a unique constraint)
if not results:
raise NoResultFound(f"Tool with name {tool_name} and organization_id {organization_id} not found.")
if len(results) > 1:
raise RuntimeError(
f"Multiple tools with name {tool_name} and organization_id {organization_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error."
)
# Return the single result
return results[0]
@enforce_types
def list_tools_for_org(self, organization_id: str, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
"""List all tools with optional pagination using cursor and limit."""
with self.session_maker() as session:
tools = ToolModel.list(
db_session=session, cursor=cursor, limit=limit, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id)
)
return [tool.to_pydantic() for tool in tools]
@enforce_types
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate) -> None:
"""Update a tool by its ID with the given ToolUpdate object."""
with self.session_maker() as session:
# Fetch the tool by ID
tool = ToolModel.read(db_session=session, identifier=tool_id)
# Update tool attributes with only the fields that were explicitly set
update_data = tool_update.model_dump(exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(tool, key, value)
# Save the updated tool to the database
tool.update(db_session=session)
@enforce_types
def delete_tool_by_id(self, tool_id: str) -> None:
"""Delete a tool by its ID."""
with self.session_maker() as session:
try:
tool = ToolModel.read(db_session=session, identifier=tool_id)
tool.delete(db_session=session)
except NoResultFound:
raise ValueError(f"Tool with id {tool_id} not found.")
@enforce_types
def add_default_tools(self, user_id: str, org_id: str, module_name="base"):
"""Add default tools in {module_name}.py"""
full_module_name = f"letta.functions.function_sets.{module_name}"
try:
module = importlib.import_module(full_module_name)
except Exception as e:
# Handle other general exceptions
raise e
functions_to_schema = []
try:
# Load the function set
functions_to_schema = load_function_set(module)
except ValueError as e:
err = f"Error loading function set '{module_name}': {e}"
warnings.warn(err)
# create tool in db
for name, schema in functions_to_schema.items():
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
if module_name == "base":
tags.append("letta-base")
# create to tool
self.create_or_update_tool(
ToolCreate(
name=name,
tags=tags,
source_type="python",
module=schema["module"],
source_code=source_code,
json_schema=schema["json_schema"],
organization_id=org_id,
user_id=user_id,
),
)

View File

@@ -1,20 +1,20 @@
from typing import List, Optional, Tuple
from letta.constants import DEFAULT_ORG_ID, DEFAULT_USER_ID, DEFAULT_USER_NAME
# TODO: Remove this once we translate all of these to the ORM
from letta.metadata import AgentModel, AgentSourceMappingModel, SourceModel
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.orm.user import User as UserModel
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserCreate, UserUpdate
from letta.services.organization_manager import OrganizationManager
from letta.utils import enforce_types
class UserManager:
"""Manager class to handle business logic related to Users."""
DEFAULT_USER_NAME = "default_user"
DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000"
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.server import db_context
@@ -22,7 +22,7 @@ class UserManager:
self.session_maker = db_context
@enforce_types
def create_default_user(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser:
def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
"""Create the default user."""
with self.session_maker() as session:
# Make sure the org id exists
@@ -33,10 +33,10 @@ class UserManager:
# Try to retrieve the user
try:
user = UserModel.read(db_session=session, identifier=DEFAULT_USER_ID)
user = UserModel.read(db_session=session, identifier=self.DEFAULT_USER_ID)
except NoResultFound:
# If it doesn't exist, make it
user = UserModel(id=DEFAULT_USER_ID, name=DEFAULT_USER_NAME, organization_id=org_id)
user = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id)
user.create(session)
return user.to_pydantic()
@@ -73,11 +73,11 @@ class UserManager:
user = UserModel.read(db_session=session, identifier=user_id)
user.delete(session)
# TODO: Remove this once we have ORM models for the Agent, Source, and AgentSourceMapping
# TODO: Integrate this via the ORM models for the Agent, Source, and AgentSourceMapping
# Cascade delete for related models: Agent, Source, AgentSourceMapping
session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
# session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
# session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
# session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
session.commit()
@@ -91,6 +91,11 @@ class UserManager:
except NoResultFound:
raise ValueError(f"User with id {user_id} not found.")
@enforce_types
def get_default_user(self) -> PydanticUser:
"""Fetch the default user."""
return self.get_user_by_id(self.DEFAULT_USER_ID)
@enforce_types
def list_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> Tuple[Optional[str], List[PydanticUser]]:
"""List users with pagination using cursor (id) and limit."""

View File

@@ -165,16 +165,13 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
"""
from crewai_tools import ScrapeWebsiteTool
from letta.schemas.tool import Tool
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
tool = Tool.from_crewai(crewai_tool)
tool_name = tool.name
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
client.add_tool(tool)
tool = client.load_crewai_tool(crewai_tool=crewai_tool)
tool_name = tool.name
# Set up persona for tool usage
persona = f"""

View File

@@ -41,17 +41,17 @@ def test_letta_run_create_new_agent(swap_letta_config):
child = pexpect.spawn("poetry run letta run", encoding="utf-8")
# Start the letta run command
child.logfile = sys.stdout
child.expect("Creating new agent", timeout=10)
child.expect("Creating new agent", timeout=20)
# Optional: LLM model selection
try:
child.expect("Select LLM model:", timeout=10)
child.sendline("\033[B\033[B\033[B\033[B\033[B")
child.expect("Select LLM model:", timeout=20)
child.sendline("")
except (pexpect.TIMEOUT, pexpect.EOF):
print("[WARNING] LLM model selection step was skipped.")
# Optional: Embedding model selection
try:
child.expect("Select embedding model:", timeout=10)
child.expect("Select embedding model:", timeout=20)
child.sendline("text-embedding-ada-002")
except (pexpect.TIMEOUT, pexpect.EOF):
print("[WARNING] Embedding model selection step was skipped.")

View File

@@ -262,15 +262,6 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
assert human.value == "Human text", "Creating human failed"
# def test_tools(client, agent):
# tools_response = client.list_tools()
# print("TOOLS", tools_response)
#
# tool_name = "TestTool"
# tool_response = client.create_tool(name=tool_name, source_code="print('Hello World')", source_type="python")
# assert tool_response, "Creating tool failed"
def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()

View File

@@ -1,16 +1,15 @@
import uuid
from typing import Union
import pytest
from letta import create_client
from letta.client.client import LocalClient, RESTClient
from letta.client.client import LocalClient
from letta.schemas.agent import AgentState
from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
from letta.schemas.tool import Tool
from letta.schemas.tool import ToolCreate
@pytest.fixture(scope="module")
@@ -35,7 +34,7 @@ def agent(client):
assert client.get_agent(agent_state.id) is None, f"Failed to properly delete agent {agent_state.id}"
def test_agent(client: Union[LocalClient, RESTClient]):
def test_agent(client: LocalClient):
# create agent
agent_state_test = client.create_agent(
name="test_agent2",
@@ -120,18 +119,16 @@ def test_agent(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent_state_test.id)
def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent):
def test_agent_add_remove_tools(client: LocalClient, agent):
# Create and add two tools to the client
# tool 1
from composio_langchain import Action
github_tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
client.add_tool(github_tool)
github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
# tool 2
from crewai_tools import ScrapeWebsiteTool
scrape_website_tool = Tool.from_crewai(ScrapeWebsiteTool(website_url="https://www.example.com"))
client.add_tool(scrape_website_tool)
scrape_website_tool = client.load_crewai_tool(crewai_tool=ScrapeWebsiteTool(website_url="https://www.example.com"))
# assert both got added
tools = client.list_tools()
@@ -171,7 +168,7 @@ def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent):
assert scrape_website_tool.name in curr_tool_names
def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
def test_agent_with_shared_blocks(client: LocalClient):
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
existing_non_template_blocks = [persona_block, human_block]
@@ -222,7 +219,7 @@ def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
client.delete_agent(second_agent_state_test.id)
def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_memory(client: LocalClient, agent: AgentState):
# get agent memory
original_memory = client.get_in_context_memory(agent.id)
assert original_memory is not None
@@ -235,7 +232,7 @@ def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated
def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_archival_memory(client: LocalClient, agent: AgentState):
"""Test functions for interacting with archival memory store"""
# add archival memory
@@ -250,7 +247,7 @@ def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentSta
client.delete_archival_memory(agent.id, passage.id)
def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_recall_memory(client: LocalClient, agent: AgentState):
"""Test functions for interacting with recall memory store"""
# send message to the agent
@@ -274,7 +271,7 @@ def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState
assert exists
def test_tools(client: Union[LocalClient, RESTClient]):
def test_tools(client: LocalClient):
def print_tool(message: str):
"""
A tool to print a message
@@ -314,20 +311,18 @@ def test_tools(client: Union[LocalClient, RESTClient]):
assert client.get_tool(tool.id).tags == extras2
# update tool: source code
client.update_tool(tool.id, func=print_tool2)
client.update_tool(tool.id, name="print_tool2", func=print_tool2)
assert client.get_tool(tool.id).name == "print_tool2"
def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
def test_tools_from_composio_basic(client: LocalClient):
from composio_langchain import Action
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client()
tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
# create tool
client.add_tool(tool)
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
# list tools
tools = client.list_tools()
@@ -337,18 +332,15 @@ def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
# The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable
def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
def test_tools_from_crewai(client: LocalClient):
# create crewAI tool
from crewai_tools import ScrapeWebsiteTool
crewai_tool = ScrapeWebsiteTool()
# Translate to memGPT Tool
tool = Tool.from_crewai(crewai_tool)
# Add the tool
client.add_tool(tool)
tool = client.load_crewai_tool(crewai_tool=crewai_tool)
# list tools
tools = client.list_tools()
@@ -372,18 +364,15 @@ def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
assert expected_content in func(website_url=simple_webpage_url)
def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
def test_tools_from_crewai_with_params(client: LocalClient):
# create crewAI tool
from crewai_tools import ScrapeWebsiteTool
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
# Translate to memGPT Tool
tool = Tool.from_crewai(crewai_tool)
# Add the tool
client.add_tool(tool)
tool = client.load_crewai_tool(crewai_tool=crewai_tool)
# list tools
tools = client.list_tools()
@@ -404,7 +393,7 @@ def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
assert expected_content in func()
def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
def test_tools_from_langchain(client: LocalClient):
# create langchain tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
@@ -412,11 +401,10 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
# Translate to memGPT Tool
tool = Tool.from_langchain(langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"})
# Add the tool
client.add_tool(tool)
tool = client.load_langchain_tool(
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
)
# list tools
tools = client.list_tools()
@@ -436,7 +424,7 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
assert expected_content in func(query="Albert Einstein")
def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, RESTClient]):
def test_tool_creation_langchain_missing_imports(client: LocalClient):
# create langchain tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
@@ -447,4 +435,4 @@ def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, REST
# Translate to memGPT Tool
# Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
with pytest.raises(RuntimeError):
Tool.from_langchain(langchain_tool)
ToolCreate.from_langchain(langchain_tool)

View File

@@ -2,14 +2,12 @@ import pytest
from sqlalchemy import delete
import letta.utils as utils
from letta.constants import (
DEFAULT_ORG_ID,
DEFAULT_ORG_NAME,
DEFAULT_USER_ID,
DEFAULT_USER_NAME,
)
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.orm.organization import Organization
from letta.orm.tool import Tool
from letta.orm.user import User
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.services.organization_manager import OrganizationManager
utils.DEBUG = True
from letta.config import LettaConfig
@@ -18,21 +16,58 @@ from letta.server.server import SyncServer
@pytest.fixture(autouse=True)
def clear_organization_and_user_table(server: SyncServer):
def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Tool)) # Clear all records from the Tool table
session.execute(delete(User)) # Clear all records from the user table
session.execute(delete(Organization)) # Clear all records from the organization table
session.commit() # Commit the deletion
@pytest.fixture
def tool_fixture(server: SyncServer):
"""Fixture to create a tool with default settings and clean up after the test."""
def print_tool(message: str):
"""
Args:
message (str): The message to print.
Returns:
str: The message that was printed.
"""
print(message)
return message
source_code = parse_source_code(print_tool)
source_type = "python"
description = "test_description"
tags = ["test"]
org = server.organization_manager.create_default_organization()
user = server.user_manager.create_default_user()
tool_create = ToolCreate(
user_id=user.id, organization_id=org.id, description=description, tags=tags, source_code=source_code, source_type=source_type
)
derived_json_schema = derive_openai_json_schema(tool_create)
derived_name = derived_json_schema["name"]
tool_create.json_schema = derived_json_schema
tool_create.name = derived_name
tool = server.tool_manager.create_tool(tool_create)
# Yield the created tool, organization, and user for use in tests
yield {"tool": tool, "organization": org, "user": user, "tool_create": tool_create}
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
config.save()
server = SyncServer()
server = SyncServer(init_with_default_org_and_user=False)
return server
@@ -55,8 +90,8 @@ def test_list_organizations(server: SyncServer):
def test_create_default_organization(server: SyncServer):
server.organization_manager.create_default_organization()
retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID)
assert retrieved.name == DEFAULT_ORG_NAME
retrieved = server.organization_manager.get_default_organization()
assert retrieved.name == server.organization_manager.DEFAULT_ORG_NAME
def test_update_organization_name(server: SyncServer):
@@ -105,8 +140,8 @@ def test_list_users(server: SyncServer):
def test_create_default_user(server: SyncServer):
org = server.organization_manager.create_default_organization()
server.user_manager.create_default_user(org_id=org.id)
retrieved = server.user_manager.get_user_by_id(DEFAULT_USER_ID)
assert retrieved.name == DEFAULT_USER_NAME
retrieved = server.user_manager.get_default_user()
assert retrieved.name == server.user_manager.DEFAULT_USER_NAME
def test_update_user(server: SyncServer):
@@ -124,9 +159,117 @@ def test_update_user(server: SyncServer):
# Adjust name
user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b))
assert user.name == user_name_b
assert user.organization_id == DEFAULT_ORG_ID
assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID
# Adjust org id
user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id))
assert user.name == user_name_b
assert user.organization_id == test_org.id
# ======================================================================================================================
# Tool Manager Tests
# ======================================================================================================================
def test_create_tool(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
tool_create = tool_fixture["tool_create"]
user = tool_fixture["user"]
org = tool_fixture["organization"]
# Assertions to ensure the created tool matches the expected values
assert tool.user_id == user.id
assert tool.organization_id == org.id
assert tool.description == tool_create.description
assert tool.tags == tool_create.tags
assert tool.source_code == tool_create.source_code
assert tool.source_type == tool_create.source_type
assert tool.json_schema == derive_openai_json_schema(tool_create)
def test_get_tool_by_id(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
# Fetch the tool by ID using the manager method
fetched_tool = server.tool_manager.get_tool_by_id(tool.id)
# Assertions to check if the fetched tool matches the created tool
assert fetched_tool.id == tool.id
assert fetched_tool.name == tool.name
assert fetched_tool.description == tool.description
assert fetched_tool.tags == tool.tags
assert fetched_tool.source_code == tool.source_code
assert fetched_tool.source_type == tool.source_type
def test_get_tool_by_name_and_org_id(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
org = tool_fixture["organization"]
# Fetch the tool by name and organization ID
fetched_tool = server.tool_manager.get_tool_by_name_and_org_id(tool.name, org.id)
# Assertions to check if the fetched tool matches the created tool
assert fetched_tool.id == tool.id
assert fetched_tool.name == tool.name
assert fetched_tool.organization_id == org.id
assert fetched_tool.description == tool.description
assert fetched_tool.tags == tool.tags
assert fetched_tool.source_code == tool.source_code
assert fetched_tool.source_type == tool.source_type
def test_get_tool_by_name_and_user_id(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
user = tool_fixture["user"]
# Fetch the tool by name and organization ID
fetched_tool = server.tool_manager.get_tool_by_name_and_user_id(tool.name, user.id)
# Assertions to check if the fetched tool matches the created tool
assert fetched_tool.id == tool.id
assert fetched_tool.name == tool.name
assert fetched_tool.user_id == user.id
assert fetched_tool.description == tool.description
assert fetched_tool.tags == tool.tags
assert fetched_tool.source_code == tool.source_code
assert fetched_tool.source_type == tool.source_type
def test_list_tools(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
org = tool_fixture["organization"]
# List tools (should include the one created by the fixture)
tools = server.tool_manager.list_tools_for_org(organization_id=org.id)
# Assertions to check that the created tool is listed
assert len(tools) == 1
assert any(t.id == tool.id for t in tools)
def test_update_tool_by_id(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
updated_description = "updated_description"
# Create a ToolUpdate object to modify the tool's description
tool_update = ToolUpdate(description=updated_description)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id)
# Assertions to check if the update was successful
assert updated_tool.description == updated_description
def test_delete_tool_by_id(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
org = tool_fixture["organization"]
# Delete the tool using the manager method
server.tool_manager.delete_tool_by_id(tool.id)
tools = server.tool_manager.list_tools_for_org(organization_id=org.id)
assert len(tools) == 0

View File

@@ -25,7 +25,6 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.message import Message
from letta.schemas.source import SourceCreate
from letta.schemas.user import UserCreate
from letta.server.server import SyncServer
from .utils import DummyDataConnector
@@ -33,26 +32,9 @@ from .utils import DummyDataConnector
@pytest.fixture(scope="module")
def server():
# if os.getenv("OPENAI_API_KEY"):
# create_config("openai")
# credentials = LettaCredentials(
# openai_key=os.getenv("OPENAI_API_KEY"),
# )
# else: # hosted
# create_config("letta_hosted")
# credentials = LettaCredentials()
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
## set to use postgres
# config.archival_storage_uri = db_url
# config.recall_storage_uri = db_url
# config.metadata_storage_uri = db_url
# config.archival_storage_type = "postgres"
# config.recall_storage_type = "postgres"
# config.metadata_storage_type = "postgres"
config.save()
server = SyncServer()
@@ -62,7 +44,7 @@ def server():
@pytest.fixture(scope="module")
def org_id(server):
# create org
org = server.organization_manager.create_organization(name="test_org")
org = server.organization_manager.create_default_organization()
print(f"Created org\n{org.id}")
yield org.id
@@ -74,7 +56,7 @@ def org_id(server):
@pytest.fixture(scope="module")
def user_id(server, org_id):
# create user
user = server.create_user(UserCreate(name="test_user", organization_id=org_id))
user = server.user_manager.create_default_user()
print(f"Created user\n{user.id}")
yield user.id

View File

@@ -56,13 +56,13 @@ def client(request):
time.sleep(5)
print("Running client tests with server:", server_url)
else:
server_url = None
assert False, "Local client not implemented"
assert server_url is not None
client = create_client(base_url=server_url) # This yields control back to the test function
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
# Clear all records from the Tool table
yield client
@@ -93,7 +93,16 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
return message
tools = client.list_tools()
print(f"Original tools {[t.name for t in tools]}")
assert sorted([t.name for t in tools]) == sorted(
[
"archival_memory_search",
"send_message",
"pause_heartbeats",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",
]
)
tool = client.create_tool(print_tool, name="my_name", tags=["extras"])
@@ -108,13 +117,15 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
# create agent with tool
agent_state = client.create_agent(tools=[tool.name])
response = client.user_message(agent_id=agent_state.id, message="hi")
# Send message without error
client.user_message(agent_id=agent_state.id, message="hi")
def test_create_agent_tool(client):
"""Test creation of a agent tool"""
def core_memory_clear(self: Agent):
def core_memory_clear(self: "Agent"):
"""
Args:
agent (Agent): The agent to delete from memory.