fix: Fix build request data for OpenAI (#1654)
This commit is contained in:
@@ -21,6 +21,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
||||
@@ -45,17 +46,18 @@ class OpenAIClient(LLMClientBase):
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for the OpenAI API.
|
||||
"""
|
||||
if tools and self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
if tools and llm_config.put_inner_thoughts_in_kwargs:
|
||||
# Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
|
||||
# TODO(fix)
|
||||
inner_thoughts_desc = (
|
||||
INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in self.llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
)
|
||||
tools = add_inner_thoughts_to_functions(
|
||||
functions=tools,
|
||||
@@ -65,21 +67,21 @@ class OpenAIClient(LLMClientBase):
|
||||
)
|
||||
|
||||
openai_message_list = [
|
||||
cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=self.llm_config.put_inner_thoughts_in_kwargs))
|
||||
cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs))
|
||||
for m in messages
|
||||
]
|
||||
|
||||
if self.llm_config.model:
|
||||
model = self.llm_config.model
|
||||
if llm_config.model:
|
||||
model = llm_config.model
|
||||
else:
|
||||
logger.warning(f"Model type not set in llm_config: {self.llm_config.model_dump_json(indent=4)}")
|
||||
logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
|
||||
model = None
|
||||
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm check is very brittle and is a stopgap at best
|
||||
tool_choice = None
|
||||
if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle):
|
||||
if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in llm_config.handle):
|
||||
tool_choice = "auto" # TODO change to "required" once proxy supports it
|
||||
elif tools:
|
||||
# only set if tools is non-Null
|
||||
@@ -94,11 +96,11 @@ class OpenAIClient(LLMClientBase):
|
||||
tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
|
||||
tool_choice=tool_choice,
|
||||
user=str(),
|
||||
max_completion_tokens=self.llm_config.max_tokens,
|
||||
temperature=self.llm_config.temperature,
|
||||
max_completion_tokens=llm_config.max_tokens,
|
||||
temperature=llm_config.temperature,
|
||||
)
|
||||
|
||||
if "inference.memgpt.ai" in self.llm_config.model_endpoint:
|
||||
if "inference.memgpt.ai" in llm_config.model_endpoint:
|
||||
# override user id for inference.memgpt.ai
|
||||
import uuid
|
||||
|
||||
|
||||
@@ -91,11 +91,11 @@ async def test_send_llm_batch_request_async_success(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages):
|
||||
async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages, mock_agent_llm_config):
|
||||
"""
|
||||
This test verifies that if the keys in the messages and tools mappings do not match,
|
||||
a ValueError is raised.
|
||||
"""
|
||||
mismatched_tools = {"agent-2": []} # Different agent ID than in the messages mapping.
|
||||
with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."):
|
||||
await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools)
|
||||
await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools, mock_agent_llm_config)
|
||||
|
||||
Reference in New Issue
Block a user