feat: Adding init tool rule for Anthropic endpoint (#2262)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from letta.constants import (
|
||||
MESSAGE_SUMMARY_WARNING_FRAC,
|
||||
O1_BASE_TOOLS,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
STRUCTURED_OUTPUT_MODELS
|
||||
)
|
||||
from letta.errors import LLMError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
@@ -276,6 +277,7 @@ class Agent(BaseAgent):
|
||||
|
||||
# gpt-4, gpt-3.5-turbo, ...
|
||||
self.model = self.agent_state.llm_config.model
|
||||
self.check_tool_rules()
|
||||
|
||||
# state managers
|
||||
self.block_manager = BlockManager()
|
||||
@@ -381,6 +383,14 @@ class Agent(BaseAgent):
|
||||
# Create the agent in the DB
|
||||
self.update_state()
|
||||
|
||||
def check_tool_rules(self):
|
||||
if self.model not in STRUCTURED_OUTPUT_MODELS:
|
||||
if len(self.tool_rules_solver.init_tool_rules) > 1:
|
||||
raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
|
||||
self.supports_structured_output = False
|
||||
else:
|
||||
self.supports_structured_output = True
|
||||
|
||||
def update_memory_if_change(self, new_memory: Memory) -> bool:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
@@ -588,6 +598,7 @@ class Agent(BaseAgent):
|
||||
empty_response_retry_limit: int = 3,
|
||||
backoff_factor: float = 0.5, # delay multiplier for exponential backoff
|
||||
max_delay: float = 10.0, # max delay between retries
|
||||
step_count: Optional[int] = None,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Get response from LLM API with robust retry mechanism."""
|
||||
|
||||
@@ -596,6 +607,16 @@ class Agent(BaseAgent):
|
||||
self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names]
|
||||
)
|
||||
|
||||
# For the first message, force the initial tool if one is specified
|
||||
force_tool_call = None
|
||||
if (
|
||||
step_count is not None
|
||||
and step_count == 0
|
||||
and not self.supports_structured_output
|
||||
and len(self.tool_rules_solver.init_tool_rules) > 0
|
||||
):
|
||||
force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name
|
||||
|
||||
for attempt in range(1, empty_response_retry_limit + 1):
|
||||
try:
|
||||
response = create(
|
||||
@@ -606,6 +627,7 @@ class Agent(BaseAgent):
|
||||
functions_python=self.functions_python,
|
||||
function_call=function_call,
|
||||
first_message=first_message,
|
||||
force_tool_call=force_tool_call,
|
||||
stream=stream,
|
||||
stream_interface=self.interface,
|
||||
)
|
||||
@@ -897,6 +919,7 @@ class Agent(BaseAgent):
|
||||
step_count = 0
|
||||
while True:
|
||||
kwargs["first_message"] = False
|
||||
kwargs["step_count"] = step_count
|
||||
step_response = self.inner_step(
|
||||
messages=next_input_message,
|
||||
**kwargs,
|
||||
@@ -972,6 +995,7 @@ class Agent(BaseAgent):
|
||||
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
|
||||
skip_verify: bool = False,
|
||||
stream: bool = False, # TODO move to config?
|
||||
step_count: Optional[int] = None,
|
||||
) -> AgentStepResponse:
|
||||
"""Runs a single step in the agent loop (generates at most one LLM call)"""
|
||||
|
||||
@@ -1014,7 +1038,9 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
response = self._get_ai_reply(
|
||||
message_sequence=input_message_sequence,
|
||||
first_message=first_message,
|
||||
stream=stream,
|
||||
step_count=step_count,
|
||||
)
|
||||
|
||||
# Step 3: check if LLM wanted to call a function
|
||||
|
||||
@@ -2156,6 +2156,7 @@ class LocalClient(AbstractClient):
|
||||
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
|
||||
"tool_ids": tool_ids,
|
||||
"tool_rules": tool_rules,
|
||||
"include_base_tools": include_base_tools,
|
||||
"system": system,
|
||||
"agent_type": agent_type,
|
||||
"llm_config": llm_config if llm_config else self._default_llm_config,
|
||||
|
||||
@@ -48,6 +48,9 @@ BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
DEFAULT_MESSAGE_TOOL = "send_message"
|
||||
DEFAULT_MESSAGE_TOOL_KWARG = "message"
|
||||
|
||||
# Structured output models
|
||||
STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"}
|
||||
|
||||
# LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level
|
||||
LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET}
|
||||
|
||||
|
||||
@@ -99,16 +99,20 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
|
||||
- 1 level less of nesting
|
||||
- "parameters" -> "input_schema"
|
||||
"""
|
||||
tools_dict_list = []
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
tools_dict_list.append(
|
||||
{
|
||||
"name": tool.function.name,
|
||||
"description": tool.function.description,
|
||||
"input_schema": tool.function.parameters,
|
||||
formatted_tool = {
|
||||
"name" : tool.function.name,
|
||||
"description" : tool.function.description,
|
||||
"input_schema" : tool.function.parameters or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
)
|
||||
return tools_dict_list
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
|
||||
def merge_tool_results_into_user_messages(messages: List[dict]):
|
||||
|
||||
@@ -113,6 +113,7 @@ def create(
|
||||
function_call: str = "auto",
|
||||
# hint
|
||||
first_message: bool = False,
|
||||
force_tool_call: Optional[str] = None, # Force a specific tool to be called
|
||||
# use tool naming?
|
||||
# if false, will use deprecated 'functions' style
|
||||
use_tool_naming: bool = True,
|
||||
@@ -252,6 +253,16 @@ def create(
|
||||
if not use_tool_naming:
|
||||
raise NotImplementedError("Only tool calling supported on Anthropic API requests")
|
||||
|
||||
tool_call = None
|
||||
if force_tool_call is not None:
|
||||
tool_call = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": force_tool_call
|
||||
}
|
||||
}
|
||||
assert functions is not None
|
||||
|
||||
return anthropic_chat_completions_request(
|
||||
url=llm_config.model_endpoint,
|
||||
api_key=model_settings.anthropic_api_key,
|
||||
@@ -259,7 +270,7 @@ def create(
|
||||
model=llm_config.model,
|
||||
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
||||
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
|
||||
# tool_choice=function_call,
|
||||
tool_choice=tool_call,
|
||||
# user=str(user_id),
|
||||
# NOTE: max_tokens is required for Anthropic API
|
||||
max_tokens=1024, # TODO make dynamic
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"context_window": 200000,
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"model_endpoint_type": "anthropic",
|
||||
"model_endpoint": "https://api.anthropic.com/v1",
|
||||
"context_window": 200000,
|
||||
"model_wrapper": null,
|
||||
"put_inner_thoughts_in_kwargs": true
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"context_window": 16385,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://api.openai.com/v1",
|
||||
"model_wrapper": null
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.schemas.letta_message import FunctionCallMessage
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
@@ -127,3 +127,110 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none):
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
def test_check_tool_rules_with_different_models(mock_e2b_api_key_none):
|
||||
"""Test that tool rules are properly checked for different model configurations."""
|
||||
client = create_client()
|
||||
|
||||
config_files = [
|
||||
"tests/configs/llm_model_configs/claude-3-sonnet-20240229.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
||||
]
|
||||
|
||||
# Create two test tools
|
||||
t1_name = "first_secret_word"
|
||||
t2_name = "second_secret_word"
|
||||
t1 = client.create_or_update_tool(first_secret_word, name=t1_name)
|
||||
t2 = client.create_or_update_tool(second_secret_word, name=t2_name)
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=t1_name),
|
||||
InitToolRule(tool_name=t2_name)
|
||||
]
|
||||
tools = [t1, t2]
|
||||
|
||||
for config_file in config_files:
|
||||
# Setup tools
|
||||
agent_uuid = str(uuid.uuid4())
|
||||
|
||||
if "gpt-4o" in config_file:
|
||||
# Structured output model (should work with multiple init tools)
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules)
|
||||
assert agent_state is not None
|
||||
else:
|
||||
# Non-structured output model (should raise error with multiple init tools)
|
||||
with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"):
|
||||
setup_agent(client, config_file, agent_uuid=agent_uuid,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules)
|
||||
|
||||
# Cleanup
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tool rule with single initial tool
|
||||
t3_name = "third_secret_word"
|
||||
t3 = client.create_or_update_tool(third_secret_word, name=t3_name)
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=t3_name)
|
||||
]
|
||||
tools = [t3]
|
||||
for config_file in config_files:
|
||||
agent_uuid = str(uuid.uuid4())
|
||||
|
||||
# Structured output model (should work with single init tool)
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules)
|
||||
assert agent_state is not None
|
||||
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
|
||||
"""Test that the initial tool rule is enforced for the first message."""
|
||||
client = create_client()
|
||||
|
||||
# Create tool rules that require tool_a to be called first
|
||||
t1_name = "first_secret_word"
|
||||
t2_name = "second_secret_word"
|
||||
t1 = client.create_or_update_tool(first_secret_word, name=t1_name)
|
||||
t2 = client.create_or_update_tool(second_secret_word, name=t2_name)
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=t1_name),
|
||||
ChildToolRule(tool_name=t1_name, children=[t2_name]),
|
||||
]
|
||||
tools = [t1, t2]
|
||||
|
||||
# Make agent state
|
||||
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
|
||||
for i in range(3):
|
||||
agent_uuid = str(uuid.uuid4())
|
||||
agent_state = setup_agent(client, anthropic_config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = client.user_message(agent_id=agent_state.id, message="What is the second secret word?")
|
||||
|
||||
assert_sanity_checks(response)
|
||||
messages = response.messages
|
||||
|
||||
assert_invoked_function_call(messages, "first_secret_word")
|
||||
assert_invoked_function_call(messages, "second_secret_word")
|
||||
|
||||
tool_names = [t.name for t in [t1, t2]]
|
||||
tool_names += ["send_message"]
|
||||
for m in messages:
|
||||
if isinstance(m, FunctionCallMessage):
|
||||
# Check that it's equal to the first one
|
||||
assert m.function_call.name == tool_names[0]
|
||||
|
||||
# Pop out first one
|
||||
tool_names = tool_names[1:]
|
||||
|
||||
print(f"Passed iteration {i}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Implement exponential backoff with initial time of 10 seconds
|
||||
if i < 2:
|
||||
backoff_time = 10 * (2 ** i)
|
||||
time.sleep(backoff_time)
|
||||
|
||||
@@ -126,6 +126,7 @@ def test_chat_only_agent(client, mock_e2b_api_key_none):
|
||||
)
|
||||
assert chat_only_agent is not None
|
||||
assert set(chat_only_agent.memory.list_block_labels()) == {"chat_agent_persona", "chat_agent_human"}
|
||||
assert len(chat_only_agent.tools) == 1
|
||||
|
||||
for message in ["hello", "my name is not chad, my name is swoodily"]:
|
||||
client.send_message(agent_id=chat_only_agent.id, message=message, role="user")
|
||||
|
||||
Reference in New Issue
Block a user