feat: Simplify arguments for LLM clients (#1536)
This commit is contained in:
@@ -313,9 +313,7 @@ class Agent(BaseAgent):
|
||||
response = llm_client.send_llm_request(
|
||||
messages=message_sequence,
|
||||
tools=allowed_functions,
|
||||
tool_call=function_call,
|
||||
stream=stream,
|
||||
first_message=first_message,
|
||||
force_tool_call=force_tool_call,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -56,7 +56,6 @@ class AnthropicClient(LLMClientBase):
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: List[dict],
|
||||
tool_call: Optional[str],
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
prefix_fill = True
|
||||
|
||||
@@ -36,7 +36,7 @@ class GoogleAIClient(LLMClientBase):
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: List[dict],
|
||||
tool_call: Optional[str],
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
|
||||
@@ -38,12 +38,12 @@ class GoogleVertexClient(GoogleAIClient):
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: List[dict],
|
||||
tool_call: Optional[str],
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
"""
|
||||
request_data = super().build_request_data(messages, tools, tool_call)
|
||||
request_data = super().build_request_data(messages, tools, force_tool_call)
|
||||
request_data["config"] = request_data.pop("generation_config")
|
||||
request_data["config"]["tools"] = request_data.pop("tools")
|
||||
|
||||
|
||||
@@ -32,9 +32,7 @@ class LLMClientBase:
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[dict]] = None, # TODO: change to Tool object
|
||||
tool_call: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
first_message: bool = False,
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
|
||||
"""
|
||||
@@ -42,7 +40,7 @@ class LLMClientBase:
|
||||
If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over.
|
||||
Otherwise returns a ChatCompletionResponse.
|
||||
"""
|
||||
request_data = self.build_request_data(messages, tools, tool_call)
|
||||
request_data = self.build_request_data(messages, tools, force_tool_call)
|
||||
|
||||
try:
|
||||
log_event(name="llm_request_sent", attributes=request_data)
|
||||
@@ -60,9 +58,7 @@ class LLMClientBase:
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[dict]] = None, # TODO: change to Tool object
|
||||
tool_call: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
first_message: bool = False,
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
|
||||
"""
|
||||
@@ -70,7 +66,7 @@ class LLMClientBase:
|
||||
If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over.
|
||||
Otherwise returns a ChatCompletionResponse.
|
||||
"""
|
||||
request_data = self.build_request_data(messages, tools, tool_call, force_tool_call)
|
||||
request_data = self.build_request_data(messages, tools, force_tool_call)
|
||||
response_data = {}
|
||||
|
||||
try:
|
||||
@@ -90,7 +86,6 @@ class LLMClientBase:
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: List[dict],
|
||||
tool_call: Optional[str],
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
|
||||
@@ -46,7 +46,6 @@ class OpenAIClient(LLMClientBase):
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||
tool_call: Optional[str] = None, # Note: OpenAI uses tool_choice
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -76,25 +75,22 @@ class OpenAIClient(LLMClientBase):
|
||||
logger.warning(f"Model type not set in llm_config: {self.llm_config.model_dump_json(indent=4)}")
|
||||
model = None
|
||||
|
||||
if tool_call is None and tools is not None and len(tools) > 0:
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (
|
||||
self.llm_config.handle and "vllm" in self.llm_config.handle
|
||||
):
|
||||
tool_call = "auto" # TODO change to "required" once proxy supports it
|
||||
else:
|
||||
tool_call = "required"
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle):
|
||||
tool_choice = "auto" # TODO change to "required" once proxy supports it
|
||||
else:
|
||||
tool_choice = "required"
|
||||
|
||||
if tool_call not in ["none", "auto", "required"]:
|
||||
tool_call = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=tool_call))
|
||||
if force_tool_call is not None:
|
||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
|
||||
data = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=openai_message_list,
|
||||
tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
|
||||
tool_choice=tool_call,
|
||||
tool_choice=tool_choice,
|
||||
user=str(),
|
||||
max_completion_tokens=self.llm_config.max_tokens,
|
||||
temperature=self.llm_config.temperature,
|
||||
|
||||
Reference in New Issue
Block a user