fix: Fix build request data for OpenAI (#1654)

This commit is contained in:
Matthew Zhou
2025-04-09 16:31:20 -07:00
committed by GitHub
parent 39e3d3760e
commit 74e299a05f
2 changed files with 14 additions and 12 deletions

View File

@@ -21,6 +21,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
@@ -45,17 +46,18 @@ class OpenAIClient(LLMClientBase):
def build_request_data(
self,
messages: List[PydanticMessage],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
force_tool_call: Optional[str] = None,
) -> dict:
"""
Constructs a request object in the expected data format for the OpenAI API.
"""
if tools and self.llm_config.put_inner_thoughts_in_kwargs:
if tools and llm_config.put_inner_thoughts_in_kwargs:
# Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
# TODO(fix)
inner_thoughts_desc = (
INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in self.llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION
INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION
)
tools = add_inner_thoughts_to_functions(
functions=tools,
@@ -65,21 +67,21 @@ class OpenAIClient(LLMClientBase):
)
openai_message_list = [
cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=self.llm_config.put_inner_thoughts_in_kwargs))
cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs))
for m in messages
]
if self.llm_config.model:
model = self.llm_config.model
if llm_config.model:
model = llm_config.model
else:
logger.warning(f"Model type not set in llm_config: {self.llm_config.model_dump_json(indent=4)}")
logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
model = None
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# TODO(matt) move into LLMConfig
# TODO: This vllm checking is very brittle and is a patch at most
tool_choice = None
if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle):
if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in self.llm_config.handle):
tool_choice = "auto" # TODO change to "required" once proxy supports it
elif tools:
# only set if tools is non-Null
@@ -94,11 +96,11 @@ class OpenAIClient(LLMClientBase):
tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
tool_choice=tool_choice,
user=str(),
max_completion_tokens=self.llm_config.max_tokens,
temperature=self.llm_config.temperature,
max_completion_tokens=llm_config.max_tokens,
temperature=llm_config.temperature,
)
if "inference.memgpt.ai" in self.llm_config.model_endpoint:
if "inference.memgpt.ai" in llm_config.model_endpoint:
# override user id for inference.memgpt.ai
import uuid

View File

@@ -91,11 +91,11 @@ async def test_send_llm_batch_request_async_success(
@pytest.mark.asyncio
async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages):
async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages, mock_agent_llm_config):
"""
This test verifies that if the keys in the messages and tools mappings do not match,
a ValueError is raised.
"""
mismatched_tools = {"agent-2": []} # Different agent ID than in the messages mapping.
with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."):
await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools)
await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools, mock_agent_llm_config)