chore: Change create_tool endpoint on v1 routes to error instead of upsert (#2102)
This commit is contained in:
104
.github/workflows/test_groq.yml
vendored
104
.github/workflows/test_groq.yml
vendored
@@ -1,104 +0,0 @@
|
||||
name: Groq Llama 3.1 70b Capabilities Test
|
||||
|
||||
env:
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: "Setup Python, Poetry and Dependencies"
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev -E external-tools"
|
||||
|
||||
- name: Test first message contains expected function call and inner monologue
|
||||
id: test_first_message
|
||||
env:
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_returns_valid_first_message
|
||||
echo "TEST_FIRST_MESSAGE_EXIT_CODE=$?" >> $GITHUB_ENV
|
||||
continue-on-error: true
|
||||
|
||||
- name: Test model sends message with keyword
|
||||
id: test_keyword_message
|
||||
env:
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_returns_keyword
|
||||
echo "TEST_KEYWORD_MESSAGE_EXIT_CODE=$?" >> $GITHUB_ENV
|
||||
continue-on-error: true
|
||||
|
||||
- name: Test model uses external tool correctly
|
||||
id: test_external_tool
|
||||
env:
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_uses_external_tool
|
||||
echo "TEST_EXTERNAL_TOOL_EXIT_CODE=$?" >> $GITHUB_ENV
|
||||
continue-on-error: true
|
||||
|
||||
- name: Test model recalls chat memory
|
||||
id: test_chat_memory
|
||||
env:
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_recall_chat_memory
|
||||
echo "TEST_CHAT_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
|
||||
continue-on-error: true
|
||||
|
||||
- name: Test model uses 'archival_memory_search' to find secret
|
||||
id: test_archival_memory
|
||||
env:
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_archival_memory_retrieval
|
||||
echo "TEST_ARCHIVAL_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
|
||||
continue-on-error: true
|
||||
|
||||
- name: Test model can edit core memories
|
||||
id: test_core_memory
|
||||
env:
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_edit_core_memory
|
||||
echo "TEST_CORE_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
|
||||
continue-on-error: true
|
||||
|
||||
- name: Summarize test results
|
||||
if: always()
|
||||
run: |
|
||||
echo "Test Results Summary:"
|
||||
|
||||
# If the exit code is empty, treat it as a failure (❌)
|
||||
echo "Test first message: $([[ -z $TEST_FIRST_MESSAGE_EXIT_CODE || $TEST_FIRST_MESSAGE_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
|
||||
echo "Test model sends message with keyword: $([[ -z $TEST_KEYWORD_MESSAGE_EXIT_CODE || $TEST_KEYWORD_MESSAGE_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
|
||||
echo "Test model uses external tool: $([[ -z $TEST_EXTERNAL_TOOL_EXIT_CODE || $TEST_EXTERNAL_TOOL_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
|
||||
echo "Test model recalls chat memory: $([[ -z $TEST_CHAT_MEMORY_EXIT_CODE || $TEST_CHAT_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
|
||||
echo "Test model uses 'archival_memory_search' to find secret: $([[ -z $TEST_ARCHIVAL_MEMORY_EXIT_CODE || $TEST_ARCHIVAL_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
|
||||
echo "Test model can edit core memories: $([[ -z $TEST_CORE_MEMORY_EXIT_CODE || $TEST_CORE_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
|
||||
|
||||
# Check if any test failed (either non-zero or unset exit code)
|
||||
if [[ -z $TEST_FIRST_MESSAGE_EXIT_CODE || $TEST_FIRST_MESSAGE_EXIT_CODE -ne 0 || \
|
||||
-z $TEST_KEYWORD_MESSAGE_EXIT_CODE || $TEST_KEYWORD_MESSAGE_EXIT_CODE -ne 0 || \
|
||||
-z $TEST_EXTERNAL_TOOL_EXIT_CODE || $TEST_EXTERNAL_TOOL_EXIT_CODE -ne 0 || \
|
||||
-z $TEST_CHAT_MEMORY_EXIT_CODE || $TEST_CHAT_MEMORY_EXIT_CODE -ne 0 || \
|
||||
-z $TEST_ARCHIVAL_MEMORY_EXIT_CODE || $TEST_ARCHIVAL_MEMORY_EXIT_CODE -ne 0 || \
|
||||
-z $TEST_CORE_MEMORY_EXIT_CODE || $TEST_CORE_MEMORY_EXIT_CODE -ne 0 ]]; then
|
||||
echo "Some tests failed."
|
||||
exit 78
|
||||
fi
|
||||
continue-on-error: true
|
||||
@@ -30,7 +30,7 @@ def roll_d20() -> str:
|
||||
|
||||
|
||||
# create a tool from the function
|
||||
tool = client.create_tool(roll_d20)
|
||||
tool = client.create_or_update_tool(roll_d20)
|
||||
print(f"Created tool with name {tool.name}")
|
||||
|
||||
# create a new agent
|
||||
|
||||
@@ -370,7 +370,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"birthday_tool = client.create_tool(query_birthday_db)"
|
||||
"birthday_tool = client.create_or_update_tool(query_birthday_db)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -181,8 +181,8 @@
|
||||
"\n",
|
||||
"# TODO: add an archival andidate tool (provide justification) \n",
|
||||
"\n",
|
||||
"read_resume_tool = client.create_tool(read_resume) \n",
|
||||
"submit_evaluation_tool = client.create_tool(submit_evaluation)"
|
||||
"read_resume_tool = client.create_or_update_tool(read_resume) \n",
|
||||
"submit_evaluation_tool = client.create_or_update_tool(submit_evaluation)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -239,7 +239,7 @@
|
||||
" print(\"Pretend to email:\", content)\n",
|
||||
" return\n",
|
||||
"\n",
|
||||
"email_candidate_tool = client.create_tool(email_candidate)"
|
||||
"email_candidate_tool = client.create_or_update_tool(email_candidate)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -713,8 +713,8 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"# create tools \n",
|
||||
"search_candidate_tool = client.create_tool(search_candidates_db)\n",
|
||||
"consider_candidate_tool = client.create_tool(consider_candidate)\n",
|
||||
"search_candidate_tool = client.create_or_update_tool(search_candidates_db)\n",
|
||||
"consider_candidate_tool = client.create_or_update_tool(consider_candidate)\n",
|
||||
"\n",
|
||||
"# delete agent if exists \n",
|
||||
"if client.get_agent_id(\"recruiter_agent\"): \n",
|
||||
|
||||
@@ -48,8 +48,8 @@ swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provide
|
||||
swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4"))
|
||||
|
||||
# create tools
|
||||
transfer_a = swarm.client.create_tool(transfer_agent_a)
|
||||
transfer_b = swarm.client.create_tool(transfer_agent_b)
|
||||
transfer_a = swarm.client.create_or_update_tool(transfer_agent_a)
|
||||
transfer_b = swarm.client.create_or_update_tool(transfer_agent_b)
|
||||
|
||||
# create agents
|
||||
if swarm.client.get_agent_id("agentb"):
|
||||
|
||||
@@ -93,7 +93,7 @@ def main():
|
||||
functions = [first_secret_word, second_secret_word, third_secret_word, fourth_secret_word, auto_error]
|
||||
tools = []
|
||||
for func in functions:
|
||||
tool = client.create_tool(func)
|
||||
tool = client.create_or_update_tool(func)
|
||||
tools.append(tool)
|
||||
tool_names = [t.name for t in tools[:-1]]
|
||||
|
||||
|
||||
@@ -136,7 +136,7 @@ def add_tool(
|
||||
func = eval(func_def.name)
|
||||
|
||||
# 4. Add or update the tool
|
||||
tool = client.create_tool(func=func, name=name, tags=tags, update=update)
|
||||
tool = client.create_or_update_tool(func=func, name=name, tags=tags, update=update)
|
||||
print(f"Tool {tool.name} added successfully")
|
||||
|
||||
|
||||
|
||||
@@ -211,6 +211,14 @@ class AbstractClient(object):
|
||||
) -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
def create_or_update_tool(
|
||||
self,
|
||||
func,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
def update_tool(
|
||||
self,
|
||||
id: str,
|
||||
@@ -532,7 +540,7 @@ class RESTClient(AbstractClient):
|
||||
# add memory tools
|
||||
memory_functions = get_memory_functions(memory)
|
||||
for func_name, func in memory_functions.items():
|
||||
tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"])
|
||||
tool = self.create_or_update_tool(func, name=func_name, tags=["memory", "letta-base"])
|
||||
tool_names.append(tool.name)
|
||||
|
||||
# check if default configs are provided
|
||||
@@ -1440,12 +1448,6 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
tool (Tool): The created tool.
|
||||
"""
|
||||
|
||||
# TODO: check tool update code
|
||||
# TODO: check if tool already exists
|
||||
|
||||
# TODO: how to load modules?
|
||||
# parse source code/schema
|
||||
source_code = parse_source_code(func)
|
||||
source_type = "python"
|
||||
|
||||
@@ -1456,6 +1458,33 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
return Tool(**response.json())
|
||||
|
||||
def create_or_update_tool(
|
||||
self,
|
||||
func: Callable,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
|
||||
|
||||
Args:
|
||||
func (callable): The function to create a tool for.
|
||||
name: (str): Name of the tool (must be unique per-user.)
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
|
||||
Returns:
|
||||
tool (Tool): The created tool.
|
||||
"""
|
||||
source_code = parse_source_code(func)
|
||||
source_type = "python"
|
||||
|
||||
# call server function
|
||||
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
|
||||
response = requests.put(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
return Tool(**response.json())
|
||||
|
||||
def update_tool(
|
||||
self,
|
||||
id: str,
|
||||
@@ -1489,45 +1518,6 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to update tool: {response.text}")
|
||||
return Tool(**response.json())
|
||||
|
||||
# def create_tool(
|
||||
# self,
|
||||
# func,
|
||||
# name: Optional[str] = None,
|
||||
# update: Optional[bool] = True, # TODO: actually use this
|
||||
# tags: Optional[List[str]] = None,
|
||||
# ):
|
||||
# """Create a tool
|
||||
|
||||
# Args:
|
||||
# func (callable): The function to create a tool for.
|
||||
# tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
# update (bool, optional): Update the tool if it already exists. Defaults to True.
|
||||
|
||||
# Returns:
|
||||
# Tool object
|
||||
# """
|
||||
|
||||
# # TODO: check if tool already exists
|
||||
# # TODO: how to load modules?
|
||||
# # parse source code/schema
|
||||
# source_code = parse_source_code(func)
|
||||
# json_schema = generate_schema(func, name)
|
||||
# source_type = "python"
|
||||
# json_schema["name"]
|
||||
|
||||
# # create data
|
||||
# data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema, "update": update}
|
||||
# try:
|
||||
# CreateToolRequest(**data) # validate data
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"Failed to create tool: {e}, invalid input {data}")
|
||||
|
||||
# # make REST request
|
||||
# response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=data, headers=self.headers)
|
||||
# if response.status_code != 200:
|
||||
# raise ValueError(f"Failed to create tool: {response.text}")
|
||||
# return ToolModel(**response.json())
|
||||
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
|
||||
"""
|
||||
List available tools for the user.
|
||||
@@ -1977,7 +1967,7 @@ class LocalClient(AbstractClient):
|
||||
# add memory tools
|
||||
memory_functions = get_memory_functions(memory)
|
||||
for func_name, func in memory_functions.items():
|
||||
tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"])
|
||||
tool = self.create_or_update_tool(func, name=func_name, tags=["memory", "letta-base"])
|
||||
tool_names.append(tool.name)
|
||||
|
||||
self.interface.clear()
|
||||
@@ -2573,7 +2563,6 @@ class LocalClient(AbstractClient):
|
||||
tool_create = ToolCreate.from_composio(action=action)
|
||||
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
|
||||
|
||||
# TODO: Use the above function `add_tool` here as there is duplicate logic
|
||||
def create_tool(
|
||||
self,
|
||||
func,
|
||||
@@ -2601,6 +2590,42 @@ class LocalClient(AbstractClient):
|
||||
if not tags:
|
||||
tags = []
|
||||
|
||||
# call server function
|
||||
return self.server.tool_manager.create_tool(
|
||||
Tool(
|
||||
source_type=source_type,
|
||||
source_code=source_code,
|
||||
name=name,
|
||||
tags=tags,
|
||||
description=description,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
def create_or_update_tool(
|
||||
self,
|
||||
func,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
|
||||
|
||||
Args:
|
||||
func (callable): The function to create a tool for.
|
||||
name: (str): Name of the tool (must be unique per-user.)
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
description (str, optional): The description.
|
||||
|
||||
Returns:
|
||||
tool (Tool): The created tool.
|
||||
"""
|
||||
source_code = parse_source_code(func)
|
||||
source_type = "python"
|
||||
if not tags:
|
||||
tags = []
|
||||
|
||||
# call server function
|
||||
return self.server.tool_manager.create_or_update_tool(
|
||||
Tool(
|
||||
|
||||
@@ -4,3 +4,11 @@ class NoResultFound(Exception):
|
||||
|
||||
class MalformedIdError(Exception):
|
||||
"""An id not in the right format, most likely violating uuid4 format."""
|
||||
|
||||
|
||||
class UniqueConstraintViolationError(ValueError):
|
||||
"""Custom exception for unique constraint violations."""
|
||||
|
||||
|
||||
class ForeignKeyConstraintViolationError(ValueError):
|
||||
"""Custom exception for foreign key constraint violations."""
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type
|
||||
|
||||
from sqlalchemy import String, select
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.errors import (
|
||||
ForeignKeyConstraintViolationError,
|
||||
NoResultFound,
|
||||
UniqueConstraintViolationError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
@@ -102,12 +107,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
if actor:
|
||||
self._set_created_and_updated_by_fields(actor.id)
|
||||
|
||||
with db_session as session:
|
||||
session.add(self)
|
||||
session.commit()
|
||||
session.refresh(self)
|
||||
return self
|
||||
try:
|
||||
with db_session as session:
|
||||
session.add(self)
|
||||
session.commit()
|
||||
session.refresh(self)
|
||||
return self
|
||||
except DBAPIError as e:
|
||||
self._handle_dbapi_error(e)
|
||||
|
||||
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
|
||||
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
@@ -168,6 +175,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
raise ValueError(f"object {actor} has no organization accessor")
|
||||
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
|
||||
|
||||
@classmethod
|
||||
def _handle_dbapi_error(cls, e: DBAPIError):
|
||||
"""Handle database errors and raise appropriate custom exceptions."""
|
||||
orig = e.orig # Extract the original error from the DBAPIError
|
||||
error_code = None
|
||||
|
||||
# For psycopg2
|
||||
if hasattr(orig, "pgcode"):
|
||||
error_code = orig.pgcode
|
||||
# For pg8000
|
||||
elif hasattr(orig, "args") and len(orig.args) > 0:
|
||||
# The first argument contains the error details as a dictionary
|
||||
err_dict = orig.args[0]
|
||||
if isinstance(err_dict, dict):
|
||||
error_code = err_dict.get("C") # 'C' is the error code field
|
||||
logger.info(f"Extracted error_code: {error_code}")
|
||||
|
||||
# Handle unique constraint violations
|
||||
if error_code == "23505":
|
||||
raise UniqueConstraintViolationError(
|
||||
f"A unique constraint was violated for {cls.__name__}. Check your input for duplicates: {e}"
|
||||
) from e
|
||||
|
||||
# Handle foreign key violations
|
||||
if error_code == "23503":
|
||||
raise ForeignKeyConstraintViolationError(
|
||||
f"A foreign key constraint was violated for {cls.__name__}. Check your input for missing or invalid references: {e}"
|
||||
) from e
|
||||
|
||||
# Re-raise for other unhandled DBAPI errors
|
||||
raise
|
||||
|
||||
@property
|
||||
def __pydantic_model__(self) -> Type["BaseModel"]:
|
||||
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.helpers import (
|
||||
generate_composio_tool_wrapper,
|
||||
generate_langchain_tool_wrapper,
|
||||
@@ -44,6 +45,29 @@ class Tool(BaseTool):
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def populate_missing_fields(self):
|
||||
"""
|
||||
Populate missing fields: name, description, and json_schema.
|
||||
"""
|
||||
# Derive JSON schema if not provided
|
||||
if not self.json_schema:
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
|
||||
# Derive name from the JSON schema if not provided
|
||||
if not self.name:
|
||||
# TODO: This in theory could error, but name should always be on json_schema
|
||||
# TODO: Make JSON schema a typed pydantic object
|
||||
self.name = self.json_schema.get("name")
|
||||
|
||||
# Derive description from the JSON schema if not provided
|
||||
if not self.description:
|
||||
# TODO: This in theory could error, but description should always be on json_schema
|
||||
# TODO: Make JSON schema a typed pydantic object
|
||||
self.description = self.json_schema.get("description")
|
||||
|
||||
return self
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Convert tool into OpenAI representation.
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
@@ -83,12 +84,41 @@ def create_tool(
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
# Derive user and org id from actor
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
try:
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
tool = Tool(**request.model_dump())
|
||||
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
|
||||
except UniqueConstraintViolationError as e:
|
||||
# Log or print the full exception here for debugging
|
||||
print(f"Error occurred: {e}")
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors and raise an internal server error
|
||||
print(f"Unexpected error occurred: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
# Send request to create the tool
|
||||
tool = Tool(**request.model_dump())
|
||||
return server.tool_manager.create_or_update_tool(pydantic_tool=tool, actor=actor)
|
||||
|
||||
@router.put("/", response_model=Tool, operation_id="upsert_tool")
|
||||
def upsert_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Create or update a tool
|
||||
"""
|
||||
try:
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
|
||||
return tool
|
||||
except UniqueConstraintViolationError as e:
|
||||
# Log the error and raise a conflict exception
|
||||
print(f"Unique constraint violation occurred: {e}")
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors and raise an internal server error
|
||||
print(f"Unexpected error occurred: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
@router.patch("/{tool_id}", response_model=Tool, operation_id="update_tool")
|
||||
|
||||
@@ -35,9 +35,7 @@ class ToolManager:
|
||||
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Derive json_schema
|
||||
derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema(source_code=pydantic_tool.source_code)
|
||||
derived_name = pydantic_tool.name or derived_json_schema["name"]
|
||||
tool = self.get_tool_by_name(tool_name=derived_name, actor=actor)
|
||||
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = pydantic_tool.model_dump(exclude={"module"}, exclude_unset=True, exclude_none=True)
|
||||
@@ -52,8 +50,6 @@ class ToolManager:
|
||||
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
||||
)
|
||||
else:
|
||||
pydantic_tool.json_schema = derived_json_schema
|
||||
pydantic_tool.name = derived_name
|
||||
tool = self.create_tool(pydantic_tool, actor=actor)
|
||||
|
||||
return tool
|
||||
@@ -61,18 +57,15 @@ class ToolManager:
|
||||
@enforce_types
|
||||
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Create the tool
|
||||
with self.session_maker() as session:
|
||||
# Set the organization id at the ORM layer
|
||||
pydantic_tool.organization_id = actor.organization_id
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
tool_data = pydantic_tool.model_dump()
|
||||
tool = ToolModel(**tool_data)
|
||||
# The description is most likely auto-generated via the json_schema,
|
||||
# so copy it over into the top-level description field
|
||||
if tool.description is None:
|
||||
tool.description = tool.json_schema.get("description", None)
|
||||
tool.create(session, actor=actor)
|
||||
|
||||
tool.create(session, actor=actor) # Re-raise other database-related errors
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -97,11 +97,11 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none):
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Add tools
|
||||
t1 = client.create_tool(first_secret_word)
|
||||
t2 = client.create_tool(second_secret_word)
|
||||
t3 = client.create_tool(third_secret_word)
|
||||
t4 = client.create_tool(fourth_secret_word)
|
||||
t_err = client.create_tool(auto_error)
|
||||
t1 = client.create_or_update_tool(first_secret_word)
|
||||
t2 = client.create_or_update_tool(second_secret_word)
|
||||
t3 = client.create_or_update_tool(third_secret_word)
|
||||
t4 = client.create_or_update_tool(fourth_secret_word)
|
||||
t_err = client.create_or_update_tool(auto_error)
|
||||
tools = [t1, t2, t3, t4, t_err]
|
||||
|
||||
# Make tool rules
|
||||
|
||||
@@ -284,7 +284,7 @@ def test_tools(client: LocalClient):
|
||||
print(msg)
|
||||
|
||||
# create tool
|
||||
tool = client.create_tool(func=print_tool, tags=["extras"])
|
||||
tool = client.create_or_update_tool(func=print_tool, tags=["extras"])
|
||||
|
||||
# list tools
|
||||
tools = client.list_tools()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
|
||||
import letta.utils as utils
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
@@ -17,6 +16,10 @@ from letta.orm import (
|
||||
User,
|
||||
)
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.errors import (
|
||||
ForeignKeyConstraintViolationError,
|
||||
UniqueConstraintViolationError,
|
||||
)
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate
|
||||
@@ -148,7 +151,7 @@ def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_fixture(server: SyncServer, default_user, default_organization):
|
||||
def print_tool(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
|
||||
def print_tool(message: str):
|
||||
@@ -177,8 +180,8 @@ def tool_fixture(server: SyncServer, default_user, default_organization):
|
||||
|
||||
tool = server.tool_manager.create_tool(tool, actor=default_user)
|
||||
|
||||
# Yield the created tool, organization, and user for use in tests
|
||||
yield {"tool": tool}
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -340,76 +343,75 @@ def test_update_user(server: SyncServer):
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tool Manager Tests
|
||||
# ToolManager Tests
|
||||
# ======================================================================================================================
|
||||
def test_create_tool(server: SyncServer, tool_fixture, default_user, default_organization):
|
||||
tool = tool_fixture["tool"]
|
||||
|
||||
def test_create_tool(server: SyncServer, print_tool, default_user, default_organization):
|
||||
# Assertions to ensure the created tool matches the expected values
|
||||
assert tool.created_by_id == default_user.id
|
||||
assert tool.organization_id == default_organization.id
|
||||
assert print_tool.created_by_id == default_user.id
|
||||
assert print_tool.organization_id == default_organization.id
|
||||
|
||||
|
||||
def test_get_tool_by_id(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
|
||||
data = print_tool.model_dump(exclude=["id"])
|
||||
tool = PydanticTool(**data)
|
||||
|
||||
with pytest.raises(UniqueConstraintViolationError):
|
||||
server.tool_manager.create_tool(tool, actor=default_user)
|
||||
|
||||
|
||||
def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
# Fetch the tool by ID using the manager method
|
||||
fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
fetched_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
|
||||
|
||||
# Assertions to check if the fetched tool matches the created tool
|
||||
assert fetched_tool.id == tool.id
|
||||
assert fetched_tool.name == tool.name
|
||||
assert fetched_tool.description == tool.description
|
||||
assert fetched_tool.tags == tool.tags
|
||||
assert fetched_tool.source_code == tool.source_code
|
||||
assert fetched_tool.source_type == tool.source_type
|
||||
assert fetched_tool.id == print_tool.id
|
||||
assert fetched_tool.name == print_tool.name
|
||||
assert fetched_tool.description == print_tool.description
|
||||
assert fetched_tool.tags == print_tool.tags
|
||||
assert fetched_tool.source_code == print_tool.source_code
|
||||
assert fetched_tool.source_type == print_tool.source_type
|
||||
|
||||
|
||||
def test_get_tool_with_actor(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
|
||||
# Fetch the tool by name and organization ID
|
||||
fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=default_user)
|
||||
def test_get_tool_with_actor(server: SyncServer, print_tool, default_user):
|
||||
# Fetch the print_tool by name and organization ID
|
||||
fetched_tool = server.tool_manager.get_tool_by_name(print_tool.name, actor=default_user)
|
||||
|
||||
# Assertions to check if the fetched tool matches the created tool
|
||||
assert fetched_tool.id == tool.id
|
||||
assert fetched_tool.name == tool.name
|
||||
assert fetched_tool.id == print_tool.id
|
||||
assert fetched_tool.name == print_tool.name
|
||||
assert fetched_tool.created_by_id == default_user.id
|
||||
assert fetched_tool.description == tool.description
|
||||
assert fetched_tool.tags == tool.tags
|
||||
assert fetched_tool.source_code == tool.source_code
|
||||
assert fetched_tool.source_type == tool.source_type
|
||||
assert fetched_tool.description == print_tool.description
|
||||
assert fetched_tool.tags == print_tool.tags
|
||||
assert fetched_tool.source_code == print_tool.source_code
|
||||
assert fetched_tool.source_type == print_tool.source_type
|
||||
|
||||
|
||||
def test_list_tools(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
|
||||
def test_list_tools(server: SyncServer, print_tool, default_user):
|
||||
# List tools (should include the one created by the fixture)
|
||||
tools = server.tool_manager.list_tools(actor=default_user)
|
||||
|
||||
# Assertions to check that the created tool is listed
|
||||
assert len(tools) == 1
|
||||
assert any(t.id == tool.id for t in tools)
|
||||
assert any(t.id == print_tool.id for t in tools)
|
||||
|
||||
|
||||
def test_update_tool_by_id(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
def test_update_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
updated_description = "updated_description"
|
||||
|
||||
# Create a ToolUpdate object to modify the tool's description
|
||||
# Create a ToolUpdate object to modify the print_tool's description
|
||||
tool_update = ToolUpdate(description=updated_description)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
|
||||
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
|
||||
|
||||
# Assertions to check if the update was successful
|
||||
assert updated_tool.description == updated_description
|
||||
|
||||
|
||||
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture, default_user):
|
||||
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, print_tool, default_user):
|
||||
def counter_tool(counter: int):
|
||||
"""
|
||||
Args:
|
||||
@@ -424,8 +426,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
|
||||
return True
|
||||
|
||||
# Test begins
|
||||
tool = tool_fixture["tool"]
|
||||
og_json_schema = tool.json_schema
|
||||
og_json_schema = print_tool.json_schema
|
||||
|
||||
source_code = parse_source_code(counter_tool)
|
||||
|
||||
@@ -433,10 +434,10 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
|
||||
tool_update = ToolUpdate(source_code=source_code)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
|
||||
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
|
||||
|
||||
# Assertions to check if the update was successful, and json_schema is updated as well
|
||||
assert updated_tool.source_code == source_code
|
||||
@@ -446,7 +447,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
|
||||
assert updated_tool.json_schema == new_schema
|
||||
|
||||
|
||||
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture, default_user):
|
||||
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print_tool, default_user):
|
||||
def counter_tool(counter: int):
|
||||
"""
|
||||
Args:
|
||||
@@ -461,8 +462,7 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
|
||||
return True
|
||||
|
||||
# Test begins
|
||||
tool = tool_fixture["tool"]
|
||||
og_json_schema = tool.json_schema
|
||||
og_json_schema = print_tool.json_schema
|
||||
|
||||
source_code = parse_source_code(counter_tool)
|
||||
name = "counter_tool"
|
||||
@@ -471,10 +471,10 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
|
||||
tool_update = ToolUpdate(name=name, source_code=source_code)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
|
||||
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
|
||||
|
||||
# Assertions to check if the update was successful, and json_schema is updated as well
|
||||
assert updated_tool.source_code == source_code
|
||||
@@ -485,29 +485,26 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
|
||||
assert updated_tool.name == name
|
||||
|
||||
|
||||
def test_update_tool_multi_user(server: SyncServer, tool_fixture, default_user, other_user):
|
||||
tool = tool_fixture["tool"]
|
||||
def test_update_tool_multi_user(server: SyncServer, print_tool, default_user, other_user):
|
||||
updated_description = "updated_description"
|
||||
|
||||
# Create a ToolUpdate object to modify the tool's description
|
||||
# Create a ToolUpdate object to modify the print_tool's description
|
||||
tool_update = ToolUpdate(description=updated_description)
|
||||
|
||||
# Update the tool using the manager method, but WITH THE OTHER USER'S ID!
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=other_user)
|
||||
# Update the print_tool using the manager method, but WITH THE OTHER USER'S ID!
|
||||
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=other_user)
|
||||
|
||||
# Check that the created_by and last_updated_by fields are correct
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
# Fetch the updated print_tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
|
||||
|
||||
assert updated_tool.last_updated_by_id == other_user.id
|
||||
assert updated_tool.created_by_id == default_user.id
|
||||
|
||||
|
||||
def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
|
||||
# Delete the tool using the manager method
|
||||
server.tool_manager.delete_tool_by_id(tool.id, actor=default_user)
|
||||
def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
# Delete the print_tool using the manager method
|
||||
server.tool_manager.delete_tool_by_id(print_tool.id, actor=default_user)
|
||||
|
||||
tools = server.tool_manager.list_tools(actor=default_user)
|
||||
assert len(tools) == 0
|
||||
@@ -1067,7 +1064,7 @@ def test_add_block_to_agent(server, sarah_agent, default_user, default_block):
|
||||
|
||||
|
||||
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
|
||||
with pytest.raises(DBAPIError, match="violates foreign key constraint .*fk_block_id_label"):
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.blocks_agents_manager.add_block_to_agent(
|
||||
agent_id=sarah_agent.id, block_id="nonexistent_block", block_label="nonexistent_label"
|
||||
)
|
||||
@@ -1131,5 +1128,5 @@ def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user
|
||||
block_manager = BlockManager()
|
||||
block_manager.delete_block(block_id=default_block.id, actor=default_user)
|
||||
|
||||
with pytest.raises(DBAPIError, match='insert or update on table "blocks_agents" violates foreign key constraint'):
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
|
||||
@@ -12,8 +12,8 @@ def test_o1_agent():
|
||||
client = create_client()
|
||||
assert client is not None
|
||||
|
||||
thinking_tool = client.create_tool(send_thinking_message)
|
||||
final_tool = client.create_tool(send_final_message)
|
||||
thinking_tool = client.create_or_update_tool(send_thinking_message)
|
||||
final_tool = client.create_or_update_tool(send_final_message)
|
||||
|
||||
agent_state = client.create_agent(
|
||||
agent_type=AgentType.o1_agent,
|
||||
|
||||
Reference in New Issue
Block a user