diff --git a/letta/agent.py b/letta/agent.py index d09c1828..ed76ad30 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -313,9 +313,7 @@ class Agent(BaseAgent): response = llm_client.send_llm_request( messages=message_sequence, tools=allowed_functions, - tool_call=function_call, stream=stream, - first_message=first_message, force_tool_call=force_tool_call, ) else: diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index ee73c09f..3e4867b0 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -56,7 +56,6 @@ class AnthropicClient(LLMClientBase): self, messages: List[PydanticMessage], tools: List[dict], - tool_call: Optional[str], force_tool_call: Optional[str] = None, ) -> dict: prefix_fill = True diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index e42dbdbf..cf6818bc 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -36,7 +36,7 @@ class GoogleAIClient(LLMClientBase): self, messages: List[PydanticMessage], tools: List[dict], - tool_call: Optional[str], + force_tool_call: Optional[str] = None, ) -> dict: """ Constructs a request object in the expected data format for this client. diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 1c703249..56237b80 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -38,12 +38,12 @@ class GoogleVertexClient(GoogleAIClient): self, messages: List[PydanticMessage], tools: List[dict], - tool_call: Optional[str], + force_tool_call: Optional[str] = None, ) -> dict: """ Constructs a request object in the expected data format for this client. """ - request_data = super().build_request_data(messages, tools, tool_call) + request_data = super().build_request_data(messages, tools, force_tool_call) request_data["config"] = request_data.pop("generation_config") request_data["config"]["tools"] = request_data.pop("tools") diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 31f68ef8..710983e2 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -32,9 +32,7 @@ class LLMClientBase: self, messages: List[Message], tools: Optional[List[dict]] = None, # TODO: change to Tool object - tool_call: Optional[str] = None, stream: bool = False, - first_message: bool = False, force_tool_call: Optional[str] = None, ) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]: """ @@ -42,7 +40,7 @@ class LLMClientBase: If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over. Otherwise returns a ChatCompletionResponse. """ - request_data = self.build_request_data(messages, tools, tool_call) + request_data = self.build_request_data(messages, tools, force_tool_call) try: log_event(name="llm_request_sent", attributes=request_data) @@ -60,9 +58,7 @@ class LLMClientBase: self, messages: List[Message], tools: Optional[List[dict]] = None, # TODO: change to Tool object - tool_call: Optional[str] = None, stream: bool = False, - first_message: bool = False, force_tool_call: Optional[str] = None, ) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]: """ @@ -70,7 +66,7 @@ class LLMClientBase: If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over. Otherwise returns a ChatCompletionResponse. """ - request_data = self.build_request_data(messages, tools, tool_call, force_tool_call) + request_data = self.build_request_data(messages, tools, force_tool_call) response_data = {} try: @@ -90,7 +86,6 @@ class LLMClientBase: self, messages: List[Message], tools: List[dict], - tool_call: Optional[str], force_tool_call: Optional[str] = None, ) -> dict: """ diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 3ef1b188..0d996977 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -46,7 +46,6 @@ class OpenAIClient(LLMClientBase): self, messages: List[PydanticMessage], tools: Optional[List[dict]] = None, # Keep as dict for now as per base class - tool_call: Optional[str] = None, # Note: OpenAI uses tool_choice force_tool_call: Optional[str] = None, ) -> dict: """ @@ -76,25 +75,22 @@ class OpenAIClient(LLMClientBase): logger.warning(f"Model type not set in llm_config: {self.llm_config.model_dump_json(indent=4)}") model = None - if tool_call is None and tools is not None and len(tools) > 0: - # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice - # TODO(matt) move into LLMConfig - # TODO: This vllm checking is very brittle and is a patch at most - if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or ( - self.llm_config.handle and "vllm" in self.llm_config.handle - ): - tool_call = "auto" # TODO change to "required" once proxy supports it - else: - tool_call = "required" + # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice + # TODO(matt) move into LLMConfig + # TODO: This vllm checking is very brittle and is a patch at most + if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle): + tool_choice = "auto" # TODO change to "required" once proxy supports it + else: + tool_choice = "required" - if tool_call not in ["none", "auto", "required"]: - tool_call = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=tool_call)) + if force_tool_call is not None: + tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call)) data = ChatCompletionRequest( model=model, messages=openai_message_list, tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None, - tool_choice=tool_call, + tool_choice=tool_choice, user=str(), max_completion_tokens=self.llm_config.max_tokens, temperature=self.llm_config.temperature,