From f109259b0b712971f4729405f0ca8aff02566992 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 9 Apr 2025 15:56:54 -0700 Subject: [PATCH] chore: Inject LLM config directly to batch api request func (#1652) --- letta/llm_api/anthropic_client.py | 34 +++++++++++++++++---------- letta/llm_api/google_ai_client.py | 6 +++-- letta/llm_api/google_vertex_client.py | 4 +++- letta/llm_api/llm_client_base.py | 5 ++-- tests/test_llm_clients.py | 19 +++++++++++---- 5 files changed, 46 insertions(+), 22 deletions(-) diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index a5afbc08..9c72ba42 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.log import get_logger +from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall @@ -63,6 +64,7 @@ class AnthropicClient(LLMClientBase): self, agent_messages_mapping: Dict[str, List[PydanticMessage]], agent_tools_mapping: Dict[str, List[dict]], + agent_llm_config_mapping: Dict[str, LLMConfig], ) -> BetaMessageBatch: """ Sends a batch request to the Anthropic API using the provided agent messages and tools mappings. @@ -70,6 +72,7 @@ class AnthropicClient(LLMClientBase): Args: agent_messages_mapping: A dict mapping agent_id to their list of PydanticMessages. agent_tools_mapping: A dict mapping agent_id to their list of tool dicts. + agent_llm_config_mapping: A dict mapping agent_id to their LLM config Returns: BetaMessageBatch: The batch response from the Anthropic API. 
@@ -84,7 +87,11 @@ class AnthropicClient(LLMClientBase): try: requests = { - agent_id: self.build_request_data(messages=agent_messages_mapping[agent_id], tools=agent_tools_mapping[agent_id]) + agent_id: self.build_request_data( + messages=agent_messages_mapping[agent_id], + llm_config=agent_llm_config_mapping[agent_id], + tools=agent_tools_mapping[agent_id], + ) for agent_id in agent_messages_mapping } @@ -114,6 +121,7 @@ class AnthropicClient(LLMClientBase): def build_request_data( self, messages: List[PydanticMessage], + llm_config: LLMConfig, tools: Optional[List[dict]] = None, force_tool_call: Optional[str] = None, ) -> dict: @@ -123,20 +131,20 @@ class AnthropicClient(LLMClientBase): if not self.use_tool_naming: raise NotImplementedError("Only tool calling supported on Anthropic API requests") - if not self.llm_config.max_tokens: + if not llm_config.max_tokens: raise ValueError("Max tokens must be set for anthropic") data = { - "model": self.llm_config.model, - "max_tokens": self.llm_config.max_tokens, - "temperature": self.llm_config.temperature, + "model": llm_config.model, + "max_tokens": llm_config.max_tokens, + "temperature": llm_config.temperature, } # Extended Thinking - if self.llm_config.enable_reasoner: + if llm_config.enable_reasoner: data["thinking"] = { "type": "enabled", - "budget_tokens": self.llm_config.max_reasoning_tokens, + "budget_tokens": llm_config.max_reasoning_tokens, } # `temperature` may only be set to 1 when thinking is enabled. 
Please consult our documentation at https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking' data["temperature"] = 1.0 @@ -156,13 +164,13 @@ class AnthropicClient(LLMClientBase): tools_for_request = [Tool(function=f) for f in tools if f["name"] == force_tool_call] # need to have this setting to be able to put inner thoughts in kwargs - if not self.llm_config.put_inner_thoughts_in_kwargs: + if not llm_config.put_inner_thoughts_in_kwargs: logger.warning( f"Force setting put_inner_thoughts_in_kwargs to True for Claude because there is a forced tool call: {force_tool_call}" ) - self.llm_config.put_inner_thoughts_in_kwargs = True + llm_config.put_inner_thoughts_in_kwargs = True else: - if self.llm_config.put_inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: # tool_choice_type other than "auto" only plays nice if thinking goes inside the tool calls tool_choice = {"type": "any", "disable_parallel_tool_use": True} else: @@ -175,7 +183,7 @@ class AnthropicClient(LLMClientBase): # Add inner thoughts kwarg # TODO: Can probably make this more efficient - if tools_for_request and len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs: + if tools_for_request and len(tools_for_request) > 0 and llm_config.put_inner_thoughts_in_kwargs: tools_with_inner_thoughts = add_inner_thoughts_to_functions( functions=[t.function.model_dump() for t in tools_for_request], inner_thoughts_key=INNER_THOUGHTS_KWARG, @@ -197,7 +205,7 @@ class AnthropicClient(LLMClientBase): data["messages"] = [ m.to_anthropic_dict( inner_thoughts_xml_tag=inner_thoughts_xml_tag, - put_inner_thoughts_in_kwargs=bool(self.llm_config.put_inner_thoughts_in_kwargs), + put_inner_thoughts_in_kwargs=bool(llm_config.put_inner_thoughts_in_kwargs), ) for m in messages[1:] ] @@ -213,7 +221,7 @@ class AnthropicClient(LLMClientBase): # https://docs.anthropic.com/en/api/messages#body-messages # NOTE: cannot prefill with 
tools for opus: # Your API request included an `assistant` message in the final position, which would pre-fill the `assistant` response. When using tools with "claude-3-opus-20240229" - if prefix_fill and not self.llm_config.put_inner_thoughts_in_kwargs and "opus" not in data["model"]: + if prefix_fill and not llm_config.put_inner_thoughts_in_kwargs and "opus" not in data["model"]: data["messages"].append( # Start the thinking process for the assistant {"role": "assistant", "content": f"<{inner_thoughts_xml_tag}>"}, diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index 730f9b9e..430550be 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -11,6 +11,7 @@ from letta.llm_api.helpers import make_post_request from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.local_llm.utils import count_tokens +from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics @@ -36,6 +37,7 @@ class GoogleAIClient(LLMClientBase): def build_request_data( self, messages: List[PydanticMessage], + llm_config: LLMConfig, tools: List[dict], force_tool_call: Optional[str] = None, ) -> dict: @@ -55,8 +57,8 @@ class GoogleAIClient(LLMClientBase): "contents": contents, "tools": tools, "generation_config": { - "temperature": self.llm_config.temperature, - "max_output_tokens": self.llm_config.max_tokens, + "temperature": llm_config.temperature, + "max_output_tokens": llm_config.max_tokens, }, } diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 56237b80..937dbe22 100644 --- a/letta/llm_api/google_vertex_client.py +++ 
b/letta/llm_api/google_vertex_client.py @@ -9,6 +9,7 @@ from letta.helpers.json_helpers import json_dumps from letta.llm_api.google_ai_client import GoogleAIClient from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.local_llm.utils import count_tokens +from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics from letta.settings import model_settings @@ -37,13 +38,14 @@ def build_request_data( self, messages: List[PydanticMessage], + llm_config: LLMConfig, tools: List[dict], force_tool_call: Optional[str] = None, ) -> dict: """ Constructs a request object in the expected data format for this client. """ - request_data = super().build_request_data(messages, tools, force_tool_call) + request_data = super().build_request_data(messages, llm_config, tools, force_tool_call) request_data["config"] = request_data.pop("generation_config") request_data["config"]["tools"] = request_data.pop("tools") diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 4340813b..bc6f5be5 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -39,7 +39,7 @@ class LLMClientBase: If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over. Otherwise returns a ChatCompletionResponse. """ - request_data = self.build_request_data(messages, tools, force_tool_call) + request_data = self.build_request_data(messages, self.llm_config, tools, force_tool_call) try: log_event(name="llm_request_sent", attributes=request_data)
""" - request_data = self.build_request_data(messages, tools, force_tool_call) + request_data = self.build_request_data(messages, self.llm_config, tools, force_tool_call) try: log_event(name="llm_request_sent", attributes=request_data) @@ -88,6 +88,7 @@ class LLMClientBase: def build_request_data( self, messages: List[Message], + llm_config: LLMConfig, tools: List[dict], force_tool_call: Optional[str] = None, ) -> dict: diff --git a/tests/test_llm_clients.py b/tests/test_llm_clients.py index 8caf726f..bcb0b78e 100644 --- a/tests/test_llm_clients.py +++ b/tests/test_llm_clients.py @@ -11,8 +11,8 @@ from letta.schemas.message import Message as PydanticMessage @pytest.fixture -def anthropic_client(): - llm_config = LLMConfig( +def llm_config(): + yield LLMConfig( model="claude-3-7-sonnet-20250219", model_endpoint_type="anthropic", model_endpoint="https://api.anthropic.com/v1", @@ -23,6 +23,10 @@ def anthropic_client(): enable_reasoner=True, max_reasoning_tokens=1024, ) + + +@pytest.fixture +def anthropic_client(llm_config): return AnthropicClient(llm_config=llm_config) @@ -57,8 +61,15 @@ def mock_agent_tools(): } +@pytest.fixture +def mock_agent_llm_config(llm_config): + return {"agent-1": llm_config} + + @pytest.mark.asyncio -async def test_send_llm_batch_request_async_success(anthropic_client, mock_agent_messages, mock_agent_tools, dummy_beta_message_batch): +async def test_send_llm_batch_request_async_success( + anthropic_client, mock_agent_messages, mock_agent_tools, mock_agent_llm_config, dummy_beta_message_batch +): """Test a successful batch request using mocked Anthropic client responses.""" # Patch the _get_anthropic_client method so that it returns a mock client. with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client: @@ -68,7 +79,7 @@ async def test_send_llm_batch_request_async_success(anthropic_client, mock_agent mock_get_client.return_value = mock_client # Call the method under test. 
- response = await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mock_agent_tools) + response = await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mock_agent_tools, mock_agent_llm_config) # Assert that the response is our dummy response. assert response.id == dummy_beta_message_batch.id