fix: patch summarizer for google and use new client (#1639)
This commit is contained in:
@@ -376,7 +376,6 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}")
|
||||
log_telemetry(self.logger, "_handle_ai_response finish")
|
||||
return response
|
||||
|
||||
except ValueError as ve:
|
||||
if attempt >= empty_response_retry_limit:
|
||||
@@ -393,6 +392,14 @@ class Agent(BaseAgent):
|
||||
log_telemetry(self.logger, "_handle_ai_response finish generic Exception")
|
||||
raise e
|
||||
|
||||
# check if we are going over the context window: this allows for articifial constraints
|
||||
if response.usage.total_tokens > self.agent_state.llm_config.context_window:
|
||||
# trigger summarization
|
||||
log_telemetry(self.logger, "_get_ai_reply summarize_messages_inplace")
|
||||
self.summarize_messages_inplace()
|
||||
# return the response
|
||||
return response
|
||||
|
||||
log_telemetry(self.logger, "_handle_ai_response finish catch-all exception")
|
||||
raise Exception("Retries exhausted and no valid response received.")
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class AnthropicClient(LLMClientBase):
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: List[dict],
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
|
||||
@@ -146,11 +146,12 @@ class AnthropicClient(LLMClientBase):
|
||||
tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None
|
||||
|
||||
# Add tool choice
|
||||
data["tool_choice"] = tool_choice
|
||||
if tool_choice:
|
||||
data["tool_choice"] = tool_choice
|
||||
|
||||
# Add inner thoughts kwarg
|
||||
# TODO: Can probably make this more efficient
|
||||
if len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
if tools_for_request and len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
tools_with_inner_thoughts = add_inner_thoughts_to_functions(
|
||||
functions=[t.function.model_dump() for t in tools_for_request],
|
||||
inner_thoughts_key=INNER_THOUGHTS_KWARG,
|
||||
@@ -158,7 +159,7 @@ class AnthropicClient(LLMClientBase):
|
||||
)
|
||||
tools_for_request = [Tool(function=f) for f in tools_with_inner_thoughts]
|
||||
|
||||
if len(tools_for_request) > 0:
|
||||
if tools_for_request and len(tools_for_request) > 0:
|
||||
# TODO eventually enable parallel tool use
|
||||
data["tools"] = convert_tools_to_anthropic_format(tools_for_request)
|
||||
|
||||
|
||||
@@ -78,9 +78,11 @@ class OpenAIClient(LLMClientBase):
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
tool_choice = None
|
||||
if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle):
|
||||
tool_choice = "auto" # TODO change to "required" once proxy supports it
|
||||
else:
|
||||
elif tools:
|
||||
# only set if tools is non-Null
|
||||
tool_choice = "required"
|
||||
|
||||
if force_tool_call is not None:
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Callable, Dict, List
|
||||
|
||||
from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
@@ -9,6 +10,7 @@ from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import count_tokens, printd
|
||||
|
||||
|
||||
@@ -45,6 +47,7 @@ def _format_summary_history(message_history: List[Message]):
|
||||
return "\n".join([f"{m.role}: {get_message_text(m.content)}" for m in message_history])
|
||||
|
||||
|
||||
@trace_method
|
||||
def summarize_messages(
|
||||
agent_state: AgentState,
|
||||
message_sequence_to_summarize: List[Message],
|
||||
@@ -74,12 +77,25 @@ def summarize_messages(
|
||||
# TODO: We need to eventually have a separate LLM config for the summarizer LLM
|
||||
llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
|
||||
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
|
||||
response = create(
|
||||
|
||||
llm_client = LLMClient.create(
|
||||
llm_config=llm_config_no_inner_thoughts,
|
||||
user_id=agent_state.created_by_id,
|
||||
messages=message_sequence,
|
||||
stream=False,
|
||||
put_inner_thoughts_first=False,
|
||||
)
|
||||
# try to use new client, otherwise fallback to old flow
|
||||
# TODO: we can just directly call the LLM here?
|
||||
if llm_client:
|
||||
response = llm_client.send_llm_request(
|
||||
messages=message_sequence,
|
||||
stream=False,
|
||||
)
|
||||
else:
|
||||
response = create(
|
||||
llm_config=llm_config_no_inner_thoughts,
|
||||
user_id=agent_state.created_by_id,
|
||||
messages=message_sequence,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
printd(f"summarize_messages gpt reply: {response.choices[0]}")
|
||||
reply = response.choices[0].message.content
|
||||
|
||||
@@ -130,21 +130,6 @@ def test_summarize_many_messages_basic(client, disable_e2b_api_key):
|
||||
client.delete_agent(small_agent_state.id)
|
||||
|
||||
|
||||
def test_summarize_large_message_does_not_loop_infinitely(client, disable_e2b_api_key):
|
||||
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
|
||||
small_context_llm_config.context_window = 2000
|
||||
small_agent_state = client.create_agent(
|
||||
name="super_small_context_agent",
|
||||
llm_config=small_context_llm_config,
|
||||
)
|
||||
with pytest.raises(ContextWindowExceededError, match=f"Ran summarizer {summarizer_settings.max_summarizer_retries}"):
|
||||
client.user_message(
|
||||
agent_id=small_agent_state.id,
|
||||
message="hi " * 1000,
|
||||
)
|
||||
client.delete_agent(small_agent_state.id)
|
||||
|
||||
|
||||
def test_summarize_messages_inplace(client, agent_state, disable_e2b_api_key):
|
||||
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
|
||||
# First send a few messages (5)
|
||||
|
||||
Reference in New Issue
Block a user