chore: Change create_tool endpoint on v1 routes to error instead of upsert (#2102)

This commit is contained in:
Matthew Zhou
2024-11-25 10:46:15 -08:00
committed by GitHub
parent f237717ce4
commit 8711e1dc00
17 changed files with 271 additions and 259 deletions

View File

@@ -1,104 +0,0 @@
# CI workflow: runs the Letta model-capability test suite against Groq's
# Llama 3.1 70b and prints a pass/fail summary instead of failing fast.
name: Groq Llama 3.1 70b Capabilities Test

env:
  GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    timeout-minutes: 15
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: "Setup Python, Poetry and Dependencies"
        uses: packetcoders/action-setup-cache-python-poetry@main
        with:
          python-version: "3.12"
          poetry-version: "1.8.2"
          install-args: "-E dev -E external-tools"

      # Each test step records pytest's exit code into $GITHUB_ENV and uses
      # continue-on-error so later steps still run.
      # NOTE(review): under the default `bash -e` shell a failing pytest
      # aborts the script before the echo runs, leaving the variable unset —
      # the summary step below deliberately treats "unset" as a failure.
      - name: Test first message contains expected function call and inner monologue
        id: test_first_message
        env:
          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
        run: |
          poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_returns_valid_first_message
          echo "TEST_FIRST_MESSAGE_EXIT_CODE=$?" >> $GITHUB_ENV
        continue-on-error: true

      - name: Test model sends message with keyword
        id: test_keyword_message
        env:
          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
        run: |
          poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_returns_keyword
          echo "TEST_KEYWORD_MESSAGE_EXIT_CODE=$?" >> $GITHUB_ENV
        continue-on-error: true

      - name: Test model uses external tool correctly
        id: test_external_tool
        env:
          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
        run: |
          poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_uses_external_tool
          echo "TEST_EXTERNAL_TOOL_EXIT_CODE=$?" >> $GITHUB_ENV
        continue-on-error: true

      - name: Test model recalls chat memory
        id: test_chat_memory
        env:
          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
        run: |
          poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_recall_chat_memory
          echo "TEST_CHAT_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
        continue-on-error: true

      - name: Test model uses 'archival_memory_search' to find secret
        id: test_archival_memory
        env:
          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
        run: |
          poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_archival_memory_retrieval
          echo "TEST_ARCHIVAL_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
        continue-on-error: true

      - name: Test model can edit core memories
        id: test_core_memory
        env:
          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
        run: |
          poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_groq_llama31_70b_edit_core_memory
          echo "TEST_CORE_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
        continue-on-error: true

      # Summarize all outcomes; an unset exit-code variable counts as a failure.
      - name: Summarize test results
        if: always()
        run: |
          echo "Test Results Summary:"
          # If the exit code is empty, treat it as a failure (❌)
          echo "Test first message: $([[ -z $TEST_FIRST_MESSAGE_EXIT_CODE || $TEST_FIRST_MESSAGE_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
          echo "Test model sends message with keyword: $([[ -z $TEST_KEYWORD_MESSAGE_EXIT_CODE || $TEST_KEYWORD_MESSAGE_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
          echo "Test model uses external tool: $([[ -z $TEST_EXTERNAL_TOOL_EXIT_CODE || $TEST_EXTERNAL_TOOL_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
          echo "Test model recalls chat memory: $([[ -z $TEST_CHAT_MEMORY_EXIT_CODE || $TEST_CHAT_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
          echo "Test model uses 'archival_memory_search' to find secret: $([[ -z $TEST_ARCHIVAL_MEMORY_EXIT_CODE || $TEST_ARCHIVAL_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
          echo "Test model can edit core memories: $([[ -z $TEST_CORE_MEMORY_EXIT_CODE || $TEST_CORE_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
          # Check if any test failed (either non-zero or unset exit code)
          if [[ -z $TEST_FIRST_MESSAGE_EXIT_CODE || $TEST_FIRST_MESSAGE_EXIT_CODE -ne 0 || \
                -z $TEST_KEYWORD_MESSAGE_EXIT_CODE || $TEST_KEYWORD_MESSAGE_EXIT_CODE -ne 0 || \
                -z $TEST_EXTERNAL_TOOL_EXIT_CODE || $TEST_EXTERNAL_TOOL_EXIT_CODE -ne 0 || \
                -z $TEST_CHAT_MEMORY_EXIT_CODE || $TEST_CHAT_MEMORY_EXIT_CODE -ne 0 || \
                -z $TEST_ARCHIVAL_MEMORY_EXIT_CODE || $TEST_ARCHIVAL_MEMORY_EXIT_CODE -ne 0 || \
                -z $TEST_CORE_MEMORY_EXIT_CODE || $TEST_CORE_MEMORY_EXIT_CODE -ne 0 ]]; then
            echo "Some tests failed."
            exit 78
          fi
        continue-on-error: true

View File

@@ -30,7 +30,7 @@ def roll_d20() -> str:
# create a tool from the function
tool = client.create_tool(roll_d20)
tool = client.create_or_update_tool(roll_d20)
print(f"Created tool with name {tool.name}")
# create a new agent

View File

@@ -370,7 +370,7 @@
"metadata": {},
"outputs": [],
"source": [
"birthday_tool = client.create_tool(query_birthday_db)"
"birthday_tool = client.create_or_update_tool(query_birthday_db)"
]
},
{

View File

@@ -181,8 +181,8 @@
"\n",
"# TODO: add an archival andidate tool (provide justification) \n",
"\n",
"read_resume_tool = client.create_tool(read_resume) \n",
"submit_evaluation_tool = client.create_tool(submit_evaluation)"
"read_resume_tool = client.create_or_update_tool(read_resume) \n",
"submit_evaluation_tool = client.create_or_update_tool(submit_evaluation)"
]
},
{
@@ -239,7 +239,7 @@
" print(\"Pretend to email:\", content)\n",
" return\n",
"\n",
"email_candidate_tool = client.create_tool(email_candidate)"
"email_candidate_tool = client.create_or_update_tool(email_candidate)"
]
},
{
@@ -713,8 +713,8 @@
"\n",
"\n",
"# create tools \n",
"search_candidate_tool = client.create_tool(search_candidates_db)\n",
"consider_candidate_tool = client.create_tool(consider_candidate)\n",
"search_candidate_tool = client.create_or_update_tool(search_candidates_db)\n",
"consider_candidate_tool = client.create_or_update_tool(consider_candidate)\n",
"\n",
"# delete agent if exists \n",
"if client.get_agent_id(\"recruiter_agent\"): \n",

View File

@@ -48,8 +48,8 @@ swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provide
swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4"))
# create tools
transfer_a = swarm.client.create_tool(transfer_agent_a)
transfer_b = swarm.client.create_tool(transfer_agent_b)
transfer_a = swarm.client.create_or_update_tool(transfer_agent_a)
transfer_b = swarm.client.create_or_update_tool(transfer_agent_b)
# create agents
if swarm.client.get_agent_id("agentb"):

View File

@@ -93,7 +93,7 @@ def main():
functions = [first_secret_word, second_secret_word, third_secret_word, fourth_secret_word, auto_error]
tools = []
for func in functions:
tool = client.create_tool(func)
tool = client.create_or_update_tool(func)
tools.append(tool)
tool_names = [t.name for t in tools[:-1]]

View File

@@ -136,7 +136,7 @@ def add_tool(
func = eval(func_def.name)
# 4. Add or update the tool
tool = client.create_tool(func=func, name=name, tags=tags, update=update)
tool = client.create_or_update_tool(func=func, name=name, tags=tags, update=update)
print(f"Tool {tool.name} added successfully")

View File

@@ -211,6 +211,14 @@ class AbstractClient(object):
) -> Tool:
raise NotImplementedError
def create_or_update_tool(
self,
func,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> Tool:
raise NotImplementedError
def update_tool(
self,
id: str,
@@ -532,7 +540,7 @@ class RESTClient(AbstractClient):
# add memory tools
memory_functions = get_memory_functions(memory)
for func_name, func in memory_functions.items():
tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"])
tool = self.create_or_update_tool(func, name=func_name, tags=["memory", "letta-base"])
tool_names.append(tool.name)
# check if default configs are provided
@@ -1440,12 +1448,6 @@ class RESTClient(AbstractClient):
Returns:
tool (Tool): The created tool.
"""
# TODO: check tool update code
# TODO: check if tool already exists
# TODO: how to load modules?
# parse source code/schema
source_code = parse_source_code(func)
source_type = "python"
@@ -1456,6 +1458,33 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to create tool: {response.text}")
return Tool(**response.json())
def create_or_update_tool(
    self,
    func: Callable,
    name: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> Tool:
    """
    Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.

    Args:
        func (callable): The function to create a tool for.
        name: (str): Name of the tool (must be unique per-user.)
        tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.

    Returns:
        tool (Tool): The created tool.
    """
    # Serialize the function body; the server re-derives the JSON schema from it.
    source_code = parse_source_code(func)
    source_type = "python"

    # call server function
    # PUT targets the upsert endpoint, so an existing tool with the same
    # name is updated rather than raising a conflict.
    request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
    response = requests.put(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
    if response.status_code != 200:
        raise ValueError(f"Failed to create tool: {response.text}")
    return Tool(**response.json())
def update_tool(
self,
id: str,
@@ -1489,45 +1518,6 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to update tool: {response.text}")
return Tool(**response.json())
# def create_tool(
# self,
# func,
# name: Optional[str] = None,
# update: Optional[bool] = True, # TODO: actually use this
# tags: Optional[List[str]] = None,
# ):
# """Create a tool
# Args:
# func (callable): The function to create a tool for.
# tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
# update (bool, optional): Update the tool if it already exists. Defaults to True.
# Returns:
# Tool object
# """
# # TODO: check if tool already exists
# # TODO: how to load modules?
# # parse source code/schema
# source_code = parse_source_code(func)
# json_schema = generate_schema(func, name)
# source_type = "python"
# json_schema["name"]
# # create data
# data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema, "update": update}
# try:
# CreateToolRequest(**data) # validate data
# except Exception as e:
# raise ValueError(f"Failed to create tool: {e}, invalid input {data}")
# # make REST request
# response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=data, headers=self.headers)
# if response.status_code != 200:
# raise ValueError(f"Failed to create tool: {response.text}")
# return ToolModel(**response.json())
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
"""
List available tools for the user.
@@ -1977,7 +1967,7 @@ class LocalClient(AbstractClient):
# add memory tools
memory_functions = get_memory_functions(memory)
for func_name, func in memory_functions.items():
tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"])
tool = self.create_or_update_tool(func, name=func_name, tags=["memory", "letta-base"])
tool_names.append(tool.name)
self.interface.clear()
@@ -2573,7 +2563,6 @@ class LocalClient(AbstractClient):
tool_create = ToolCreate.from_composio(action=action)
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
# TODO: Use the above function `add_tool` here as there is duplicate logic
def create_tool(
self,
func,
@@ -2601,6 +2590,42 @@ class LocalClient(AbstractClient):
if not tags:
tags = []
# call server function
return self.server.tool_manager.create_tool(
Tool(
source_type=source_type,
source_code=source_code,
name=name,
tags=tags,
description=description,
),
actor=self.user,
)
def create_or_update_tool(
self,
func,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
description: Optional[str] = None,
) -> Tool:
"""
Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
Args:
func (callable): The function to create a tool for.
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
description (str, optional): The description.
Returns:
tool (Tool): The created tool.
"""
source_code = parse_source_code(func)
source_type = "python"
if not tags:
tags = []
# call server function
return self.server.tool_manager.create_or_update_tool(
Tool(

View File

@@ -4,3 +4,11 @@ class NoResultFound(Exception):
class MalformedIdError(Exception):
"""An id not in the right format, most likely violating uuid4 format."""
# Both exceptions subclass ValueError so existing callers that catch
# ValueError keep working while new code can catch the specific type.
class UniqueConstraintViolationError(ValueError):
    """Custom exception for unique constraint violations."""


class ForeignKeyConstraintViolationError(ValueError):
    """Custom exception for foreign key constraint violations."""

View File

@@ -1,11 +1,16 @@
from typing import TYPE_CHECKING, List, Literal, Optional, Type
from sqlalchemy import String, select
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Mapped, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import NoResultFound
from letta.orm.errors import (
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
)
if TYPE_CHECKING:
from pydantic import BaseModel
@@ -102,12 +107,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if actor:
self._set_created_and_updated_by_fields(actor.id)
with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self
try:
with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self
except DBAPIError as e:
self._handle_dbapi_error(e)
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -168,6 +175,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
raise ValueError(f"object {actor} has no organization accessor")
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
@classmethod
def _handle_dbapi_error(cls, e: DBAPIError):
    """Handle database errors and raise appropriate custom exceptions.

    Maps SQLSTATE 23505 to UniqueConstraintViolationError and 23503 to
    ForeignKeyConstraintViolationError; any other DBAPIError is re-raised
    unchanged.

    NOTE(review): ends with a bare ``raise``, so this helper must be called
    from inside an ``except`` block (it re-raises the active exception).
    """
    orig = e.orig  # Extract the original error from the DBAPIError
    error_code = None

    # For psycopg2
    if hasattr(orig, "pgcode"):
        error_code = orig.pgcode
    # For pg8000
    elif hasattr(orig, "args") and len(orig.args) > 0:
        # The first argument contains the error details as a dictionary
        err_dict = orig.args[0]
        if isinstance(err_dict, dict):
            error_code = err_dict.get("C")  # 'C' is the error code field

    logger.info(f"Extracted error_code: {error_code}")

    # Handle unique constraint violations
    if error_code == "23505":
        raise UniqueConstraintViolationError(
            f"A unique constraint was violated for {cls.__name__}. Check your input for duplicates: {e}"
        ) from e

    # Handle foreign key violations
    if error_code == "23503":
        raise ForeignKeyConstraintViolationError(
            f"A foreign key constraint was violated for {cls.__name__}. Check your input for missing or invalid references: {e}"
        ) from e

    # Re-raise for other unhandled DBAPI errors
    raise
@property
def __pydantic_model__(self) -> Type["BaseModel"]:
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")

View File

@@ -1,7 +1,8 @@
from typing import Dict, List, Optional
from pydantic import Field
from pydantic import Field, model_validator
from letta.functions.functions import derive_openai_json_schema
from letta.functions.helpers import (
generate_composio_tool_wrapper,
generate_langchain_tool_wrapper,
@@ -44,6 +45,29 @@ class Tool(BaseTool):
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
@model_validator(mode="after")
def populate_missing_fields(self):
    """
    Populate missing fields: name, description, and json_schema.

    Runs after model construction (mode="after"); any field the caller left
    unset is derived from the tool's source code / JSON schema.
    """
    # Derive JSON schema if not provided
    if not self.json_schema:
        self.json_schema = derive_openai_json_schema(source_code=self.source_code)

    # Derive name from the JSON schema if not provided
    if not self.name:
        # TODO: This in theory could error, but name should always be on json_schema
        # TODO: Make JSON schema a typed pydantic object
        self.name = self.json_schema.get("name")

    # Derive description from the JSON schema if not provided
    if not self.description:
        # TODO: This in theory could error, but description should always be on json_schema
        # TODO: Make JSON schema a typed pydantic object
        self.description = self.json_schema.get("description")

    return self
def to_dict(self):
"""
Convert tool into OpenAI representation.

View File

@@ -2,6 +2,7 @@ from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.orm.errors import UniqueConstraintViolationError
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
@@ -83,12 +84,41 @@ def create_tool(
"""
Create a new tool
"""
# Derive user and org id from actor
actor = server.get_user_or_default(user_id=user_id)
try:
actor = server.get_user_or_default(user_id=user_id)
tool = Tool(**request.model_dump())
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
except UniqueConstraintViolationError as e:
# Log or print the full exception here for debugging
print(f"Error occurred: {e}")
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
# Catch other unexpected errors and raise an internal server error
print(f"Unexpected error occurred: {e}")
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
# Send request to create the tool
tool = Tool(**request.model_dump())
return server.tool_manager.create_or_update_tool(pydantic_tool=tool, actor=actor)
@router.put("/", response_model=Tool, operation_id="upsert_tool")
def upsert_tool(
    request: ToolCreate = Body(...),
    server: SyncServer = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Create or update a tool

    Responds 409 when a unique-constraint violation is raised by the
    manager, and 500 for any other unexpected failure.
    """
    try:
        # Resolve the acting user (falls back to a default when no header given).
        actor = server.get_user_or_default(user_id=user_id)
        tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
        return tool
    except UniqueConstraintViolationError as e:
        # Log the error and raise a conflict exception
        print(f"Unique constraint violation occurred: {e}")
        raise HTTPException(status_code=409, detail=str(e))
    except Exception as e:
        # Catch other unexpected errors and raise an internal server error
        print(f"Unexpected error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
@router.patch("/{tool_id}", response_model=Tool, operation_id="update_tool")

View File

@@ -35,9 +35,7 @@ class ToolManager:
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Derive json_schema
derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema(source_code=pydantic_tool.source_code)
derived_name = pydantic_tool.name or derived_json_schema["name"]
tool = self.get_tool_by_name(tool_name=derived_name, actor=actor)
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
if tool:
# Put to dict and remove fields that should not be reset
update_data = pydantic_tool.model_dump(exclude={"module"}, exclude_unset=True, exclude_none=True)
@@ -52,8 +50,6 @@ class ToolManager:
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
)
else:
pydantic_tool.json_schema = derived_json_schema
pydantic_tool.name = derived_name
tool = self.create_tool(pydantic_tool, actor=actor)
return tool
@@ -61,18 +57,15 @@ class ToolManager:
@enforce_types
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Create the tool
with self.session_maker() as session:
# Set the organization id at the ORM layer
pydantic_tool.organization_id = actor.organization_id
# Auto-generate description if not provided
if pydantic_tool.description is None:
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
tool_data = pydantic_tool.model_dump()
tool = ToolModel(**tool_data)
# The description is most likely auto-generated via the json_schema,
# so copy it over into the top-level description field
if tool.description is None:
tool.description = tool.json_schema.get("description", None)
tool.create(session, actor=actor)
tool.create(session, actor=actor) # Re-raise other database-related errors
return tool.to_pydantic()
@enforce_types

View File

@@ -97,11 +97,11 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none):
cleanup(client=client, agent_uuid=agent_uuid)
# Add tools
t1 = client.create_tool(first_secret_word)
t2 = client.create_tool(second_secret_word)
t3 = client.create_tool(third_secret_word)
t4 = client.create_tool(fourth_secret_word)
t_err = client.create_tool(auto_error)
t1 = client.create_or_update_tool(first_secret_word)
t2 = client.create_or_update_tool(second_secret_word)
t3 = client.create_or_update_tool(third_secret_word)
t4 = client.create_or_update_tool(fourth_secret_word)
t_err = client.create_or_update_tool(auto_error)
tools = [t1, t2, t3, t4, t_err]
# Make tool rules

View File

@@ -284,7 +284,7 @@ def test_tools(client: LocalClient):
print(msg)
# create tool
tool = client.create_tool(func=print_tool, tags=["extras"])
tool = client.create_or_update_tool(func=print_tool, tags=["extras"])
# list tools
tools = client.list_tools()

View File

@@ -1,6 +1,5 @@
import pytest
from sqlalchemy import delete
from sqlalchemy.exc import DBAPIError
import letta.utils as utils
from letta.functions.functions import derive_openai_json_schema, parse_source_code
@@ -17,6 +16,10 @@ from letta.orm import (
User,
)
from letta.orm.agents_tags import AgentsTags
from letta.orm.errors import (
ForeignKeyConstraintViolationError,
UniqueConstraintViolationError,
)
from letta.schemas.agent import CreateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate
@@ -148,7 +151,7 @@ def charles_agent(server: SyncServer, default_user, default_organization):
@pytest.fixture
def tool_fixture(server: SyncServer, default_user, default_organization):
def print_tool(server: SyncServer, default_user, default_organization):
"""Fixture to create a tool with default settings and clean up after the test."""
def print_tool(message: str):
@@ -177,8 +180,8 @@ def tool_fixture(server: SyncServer, default_user, default_organization):
tool = server.tool_manager.create_tool(tool, actor=default_user)
# Yield the created tool, organization, and user for use in tests
yield {"tool": tool}
# Yield the created tool
yield tool
@pytest.fixture
@@ -340,76 +343,75 @@ def test_update_user(server: SyncServer):
# ======================================================================================================================
# Tool Manager Tests
# ToolManager Tests
# ======================================================================================================================
def test_create_tool(server: SyncServer, tool_fixture, default_user, default_organization):
tool = tool_fixture["tool"]
def test_create_tool(server: SyncServer, print_tool, default_user, default_organization):
# Assertions to ensure the created tool matches the expected values
assert tool.created_by_id == default_user.id
assert tool.organization_id == default_organization.id
assert print_tool.created_by_id == default_user.id
assert print_tool.organization_id == default_organization.id
def test_get_tool_by_id(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
    """Creating a second tool with the same name must raise UniqueConstraintViolationError."""
    # Clone the fixture tool minus its id, so only the name collides.
    duplicate = PydanticTool(**print_tool.model_dump(exclude=["id"]))
    with pytest.raises(UniqueConstraintViolationError):
        server.tool_manager.create_tool(duplicate, actor=default_user)
def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
# Fetch the tool by ID using the manager method
fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
fetched_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
# Assertions to check if the fetched tool matches the created tool
assert fetched_tool.id == tool.id
assert fetched_tool.name == tool.name
assert fetched_tool.description == tool.description
assert fetched_tool.tags == tool.tags
assert fetched_tool.source_code == tool.source_code
assert fetched_tool.source_type == tool.source_type
assert fetched_tool.id == print_tool.id
assert fetched_tool.name == print_tool.name
assert fetched_tool.description == print_tool.description
assert fetched_tool.tags == print_tool.tags
assert fetched_tool.source_code == print_tool.source_code
assert fetched_tool.source_type == print_tool.source_type
def test_get_tool_with_actor(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
# Fetch the tool by name and organization ID
fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=default_user)
def test_get_tool_with_actor(server: SyncServer, print_tool, default_user):
# Fetch the print_tool by name and organization ID
fetched_tool = server.tool_manager.get_tool_by_name(print_tool.name, actor=default_user)
# Assertions to check if the fetched tool matches the created tool
assert fetched_tool.id == tool.id
assert fetched_tool.name == tool.name
assert fetched_tool.id == print_tool.id
assert fetched_tool.name == print_tool.name
assert fetched_tool.created_by_id == default_user.id
assert fetched_tool.description == tool.description
assert fetched_tool.tags == tool.tags
assert fetched_tool.source_code == tool.source_code
assert fetched_tool.source_type == tool.source_type
assert fetched_tool.description == print_tool.description
assert fetched_tool.tags == print_tool.tags
assert fetched_tool.source_code == print_tool.source_code
assert fetched_tool.source_type == print_tool.source_type
def test_list_tools(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
def test_list_tools(server: SyncServer, print_tool, default_user):
# List tools (should include the one created by the fixture)
tools = server.tool_manager.list_tools(actor=default_user)
# Assertions to check that the created tool is listed
assert len(tools) == 1
assert any(t.id == tool.id for t in tools)
assert any(t.id == print_tool.id for t in tools)
def test_update_tool_by_id(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
def test_update_tool_by_id(server: SyncServer, print_tool, default_user):
updated_description = "updated_description"
# Create a ToolUpdate object to modify the tool's description
# Create a ToolUpdate object to modify the print_tool's description
tool_update = ToolUpdate(description=updated_description)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
# Assertions to check if the update was successful
assert updated_tool.description == updated_description
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture, default_user):
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, print_tool, default_user):
def counter_tool(counter: int):
"""
Args:
@@ -424,8 +426,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
return True
# Test begins
tool = tool_fixture["tool"]
og_json_schema = tool.json_schema
og_json_schema = print_tool.json_schema
source_code = parse_source_code(counter_tool)
@@ -433,10 +434,10 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
tool_update = ToolUpdate(source_code=source_code)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
# Assertions to check if the update was successful, and json_schema is updated as well
assert updated_tool.source_code == source_code
@@ -446,7 +447,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
assert updated_tool.json_schema == new_schema
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture, default_user):
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print_tool, default_user):
def counter_tool(counter: int):
"""
Args:
@@ -461,8 +462,7 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
return True
# Test begins
tool = tool_fixture["tool"]
og_json_schema = tool.json_schema
og_json_schema = print_tool.json_schema
source_code = parse_source_code(counter_tool)
name = "counter_tool"
@@ -471,10 +471,10 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
tool_update = ToolUpdate(name=name, source_code=source_code)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
# Assertions to check if the update was successful, and json_schema is updated as well
assert updated_tool.source_code == source_code
@@ -485,29 +485,26 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
assert updated_tool.name == name
def test_update_tool_multi_user(server: SyncServer, tool_fixture, default_user, other_user):
tool = tool_fixture["tool"]
def test_update_tool_multi_user(server: SyncServer, print_tool, default_user, other_user):
updated_description = "updated_description"
# Create a ToolUpdate object to modify the tool's description
# Create a ToolUpdate object to modify the print_tool's description
tool_update = ToolUpdate(description=updated_description)
# Update the tool using the manager method, but WITH THE OTHER USER'S ID!
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=other_user)
# Update the print_tool using the manager method, but WITH THE OTHER USER'S ID!
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=other_user)
# Check that the created_by and last_updated_by fields are correct
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
# Fetch the updated print_tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
assert updated_tool.last_updated_by_id == other_user.id
assert updated_tool.created_by_id == default_user.id
def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
# Delete the tool using the manager method
server.tool_manager.delete_tool_by_id(tool.id, actor=default_user)
def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
# Delete the print_tool using the manager method
server.tool_manager.delete_tool_by_id(print_tool.id, actor=default_user)
tools = server.tool_manager.list_tools(actor=default_user)
assert len(tools) == 0
@@ -1067,7 +1064,7 @@ def test_add_block_to_agent(server, sarah_agent, default_user, default_block):
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
with pytest.raises(DBAPIError, match="violates foreign key constraint .*fk_block_id_label"):
with pytest.raises(ForeignKeyConstraintViolationError):
server.blocks_agents_manager.add_block_to_agent(
agent_id=sarah_agent.id, block_id="nonexistent_block", block_label="nonexistent_label"
)
@@ -1131,5 +1128,5 @@ def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user
block_manager = BlockManager()
block_manager.delete_block(block_id=default_block.id, actor=default_user)
with pytest.raises(DBAPIError, match='insert or update on table "blocks_agents" violates foreign key constraint'):
with pytest.raises(ForeignKeyConstraintViolationError):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)

View File

@@ -12,8 +12,8 @@ def test_o1_agent():
client = create_client()
assert client is not None
thinking_tool = client.create_tool(send_thinking_message)
final_tool = client.create_tool(send_final_message)
thinking_tool = client.create_or_update_tool(send_thinking_message)
final_tool = client.create_or_update_tool(send_final_message)
agent_state = client.create_agent(
agent_type=AgentType.o1_agent,