diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 856c795f..8df6c1cb 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -21,6 +21,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST from letta.log import get_logger +from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall @@ -45,17 +46,18 @@ class OpenAIClient(LLMClientBase): def build_request_data( self, messages: List[PydanticMessage], + llm_config: LLMConfig, tools: Optional[List[dict]] = None, # Keep as dict for now as per base class force_tool_call: Optional[str] = None, ) -> dict: """ Constructs a request object in the expected data format for the OpenAI API. """ - if tools and self.llm_config.put_inner_thoughts_in_kwargs: + if tools and llm_config.put_inner_thoughts_in_kwargs: # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first # TODO(fix) inner_thoughts_desc = ( - INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in self.llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION + INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION ) tools = add_inner_thoughts_to_functions( functions=tools, @@ -65,21 +67,21 @@ class OpenAIClient(LLMClientBase): ) openai_message_list = [ - cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=self.llm_config.put_inner_thoughts_in_kwargs)) + cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs)) for m in messages ] - if self.llm_config.model: - model = self.llm_config.model + if llm_config.model: + model = llm_config.model else: - logger.warning(f"Model type not set in llm_config: {self.llm_config.model_dump_json(indent=4)}") + logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}") model = None # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice # TODO(matt) move into LLMConfig # TODO: This vllm checking is very brittle and is a patch at most tool_choice = None - if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle): + if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in self.llm_config.handle): tool_choice = "auto" # TODO change to "required" once proxy supports it elif tools: # only set if tools is non-Null @@ -94,11 +96,11 @@ class OpenAIClient(LLMClientBase): tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None, tool_choice=tool_choice, user=str(), - max_completion_tokens=self.llm_config.max_tokens, - temperature=self.llm_config.temperature, + max_completion_tokens=llm_config.max_tokens, + temperature=llm_config.temperature, ) - if "inference.memgpt.ai" in self.llm_config.model_endpoint: + if "inference.memgpt.ai" in llm_config.model_endpoint: # override user id for inference.memgpt.ai import uuid diff --git a/tests/test_llm_clients.py b/tests/test_llm_clients.py index bcb0b78e..b5987b15 100644 --- a/tests/test_llm_clients.py +++ b/tests/test_llm_clients.py @@ -91,11 +91,11 @@ async def test_send_llm_batch_request_async_success( @pytest.mark.asyncio -async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages): +async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages, mock_agent_llm_config): """ This test verifies that if the keys in the messages and tools mappings do not match, a ValueError is raised. """ mismatched_tools = {"agent-2": []} # Different agent ID than in the messages mapping. with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."): - await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools) + await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools, mock_agent_llm_config)