From a75dfd790712d7262168ea896981adeabc64d72c Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Mon, 18 Nov 2024 15:21:11 -0800
Subject: [PATCH] fix: context window overflow patch (#2053)

---
 letta/agent.py                 |  3 +-
 letta/llm_api/llm_api_tools.py |  9 ++++++
 letta/local_llm/utils.py       | 13 ++++++--
 letta/services/tool_manager.py |  4 +++
 tests/test_summarize.py        | 55 +++++++++++++++++++++++++++++++++-
 tests/test_tools.py            |  5 ++++
 6 files changed, 85 insertions(+), 4 deletions(-)

diff --git a/letta/agent.py b/letta/agent.py
index 455d9eed..50264bc7 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -1584,7 +1584,8 @@ class Agent(BaseAgent):
 
     def count_tokens(self) -> int:
         """Count the tokens in the current context window"""
-        return self.get_context_window().context_window_size_current
+        context_window_breakdown = self.get_context_window()
+        return context_window_breakdown.context_window_size_current
 
 
 def save_agent(agent: Agent, ms: MetadataStore):
diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py
index 3484f720..9a6374b5 100644
--- a/letta/llm_api/llm_api_tools.py
+++ b/letta/llm_api/llm_api_tools.py
@@ -25,6 +25,7 @@ from letta.local_llm.constants import (
     INNER_THOUGHTS_KWARG,
     INNER_THOUGHTS_KWARG_DESCRIPTION,
 )
+from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_request import (
@@ -123,6 +124,14 @@ def create(
     """Return response to chat completion with backoff"""
     from letta.utils import printd
 
+    # Count the tokens first, if there's an overflow exit early by throwing an error up the stack
+    # NOTE: we want to include a specific substring in the error message to trigger summarization
+    messages_oai_format = [m.to_openai_dict() for m in messages]
+    prompt_tokens = num_tokens_from_messages(messages=messages_oai_format, model=llm_config.model)
+    function_tokens = num_tokens_from_functions(functions=functions, model=llm_config.model) if functions else 0
+    if prompt_tokens + function_tokens > llm_config.context_window:
+        raise Exception(f"Request exceeds maximum context length ({prompt_tokens + function_tokens} > {llm_config.context_window} tokens)")
+
     if not model_settings:
         from letta.settings import model_settings
 
diff --git a/letta/local_llm/utils.py b/letta/local_llm/utils.py
index cc3f0bc1..8b91f4b3 100644
--- a/letta/local_llm/utils.py
+++ b/letta/local_llm/utils.py
@@ -94,7 +94,10 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
     num_tokens = 0
     for function in functions:
         function_tokens = len(encoding.encode(function["name"]))
-        function_tokens += len(encoding.encode(function["description"]))
+        if function["description"]:
+            function_tokens += len(encoding.encode(function["description"]))
+        else:
+            raise ValueError(f"Function {function['name']} has no description, function: {function}")
 
         if "parameters" in function:
             parameters = function["parameters"]
@@ -229,7 +232,13 @@ def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int:
                     # num_tokens += len(encoding.encode(value["arguments"]))
 
                 else:
-                    num_tokens += len(encoding.encode(value))
+                    if value is None:
+                        # raise ValueError(f"Message has null value: {key} with value: {value} - message={message}")
+                        warnings.warn(f"Message has null value: {key} with value: {value} - message={message}")
+                    else:
+                        if not isinstance(value, str):
+                            raise ValueError(f"Message has non-string value: {key} with value: {value} - message={message}")
+                        num_tokens += len(encoding.encode(value))
 
                 if key == "name":
                     num_tokens += tokens_per_name
diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py
index c60b8ee1..4a705e8d 100644
--- a/letta/services/tool_manager.py
+++ b/letta/services/tool_manager.py
@@ -70,6 +70,10 @@ class ToolManager:
             pydantic_tool.organization_id = actor.organization_id
             tool_data = pydantic_tool.model_dump()
             tool = ToolModel(**tool_data)
+            # The description is most likely auto-generated via the json_schema,
+            # so copy it over into the top-level description field
+            if tool.description is None:
+                tool.description = tool.json_schema.get("description", None)
             tool.create(session, actor=actor)
 
         return tool.to_pydantic()
diff --git a/tests/test_summarize.py b/tests/test_summarize.py
index 5be183f5..dd58311e 100644
--- a/tests/test_summarize.py
+++ b/tests/test_summarize.py
@@ -1,8 +1,11 @@
 import uuid
+from typing import List
 
 from letta import create_client
+from letta.client.client import LocalClient
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
+from letta.schemas.message import Message
 
 from .utils import wipe_config
 
@@ -33,7 +36,7 @@ def create_test_agent():
     agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
 
 
-def test_summarize():
+def test_summarize_messages_inplace():
     """Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
     global client
     global agent_obj
@@ -73,3 +76,53 @@ def test_summarize_messages_inplace():
     agent_obj.summarize_messages_inplace()
     print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}")
     # response = client.run_command(agent_id=agent_obj.agent_state.id, command="summarize")
+
+
+def test_auto_summarize():
+    """Test that the summarizer triggers by itself"""
+    client = create_client()
+    client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
+    client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
+
+    small_context_llm_config = LLMConfig.default_config("gpt-4")
+    # default system prompt + funcs lead to ~2300 tokens, after one message it's at 2523 tokens
+    SMALL_CONTEXT_WINDOW = 3000
+    small_context_llm_config.context_window = SMALL_CONTEXT_WINDOW
+
+    agent_state = client.create_agent(
+        name="small_context_agent",
+        llm_config=small_context_llm_config,
+    )
+
+    try:
+
+        def summarize_message_exists(messages: List[Message]) -> bool:
+            for message in messages:
+                if message.text and "have been hidden from view due to conversation memory constraints" in message.text:
+                    return True
+            return False
+
+        MAX_ATTEMPTS = 5
+        message_count = 0
+        while True:
+
+            # send a message
+            response = client.user_message(
+                agent_id=agent_state.id,
+                message="What is the meaning of life?",
+            )
+            message_count += 1
+
+            # check if the summarize message is inside the messages
+            assert isinstance(client, LocalClient), "Test only works with LocalClient"
+            agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
+            if summarize_message_exists(agent_obj._messages):
+                # We found a summarize message
+                print(f"Summarize message found after {message_count} messages")
+                break
+
+            if message_count > MAX_ATTEMPTS:
+                raise Exception(f"Summarize message not found after {message_count} messages")
+
+    finally:
+        client.delete_agent(agent_state.id)
diff --git a/tests/test_tools.py b/tests/test_tools.py
index f7e9464c..124520ec 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -83,6 +83,8 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
 
     def print_tool(message: str):
         """
+        Example tool that prints a message
+
         Args:
             message (str): The message to print.
 
@@ -110,6 +112,7 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
     assert tool.id == tool.id, f"Expected {tool.id} to be {tool.id}"
 
     # create agent with tool
+    assert tool.name is not None, "Expected tool name to be set"
    agent_state = client.create_agent(tools=[tool.name])
 
     # Send message without error
@@ -121,6 +124,8 @@ def test_create_agent_tool(client):
 
     def core_memory_clear(self: "Agent"):
         """
+        Clear the core memory of the agent
+
         Args:
             agent (Agent): The agent to delete from memory.
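
Note on the core mechanism: the heart of this patch is the pre-flight check added to
create() in letta/llm_api/llm_api_tools.py, which counts prompt and function tokens
before the request is sent and raises if the total exceeds llm_config.context_window.
A minimal standalone sketch of the same pattern, assuming tiktoken is installed and
messages are OpenAI-format dicts (count_message_tokens and check_context_window are
illustrative names, not letta's actual helpers, which are num_tokens_from_messages
and num_tokens_from_functions):

from typing import List

import tiktoken


def count_message_tokens(messages: List[dict], model: str = "gpt-4") -> int:
    """Approximate token count for a list of OpenAI-format message dicts."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")  # fallback for unknown models
    num_tokens = 0
    for message in messages:
        num_tokens += 4  # rough per-message overhead (role markers, separators)
        for value in message.values():
            if isinstance(value, str):  # skip None values and nested tool-call structures
                num_tokens += len(encoding.encode(value))
    return num_tokens + 2  # rough reply-priming overhead


def check_context_window(messages: List[dict], context_window: int, model: str = "gpt-4") -> None:
    """Raise before calling the LLM API if the request cannot possibly fit."""
    prompt_tokens = count_message_tokens(messages, model=model)
    if prompt_tokens > context_window:
        # Keep the substring "exceeds maximum context length" in the message:
        # per the NOTE in the patch, callers match on it to trigger summarization.
        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")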
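
Upstream of that raise, the agent loop is expected to catch the overflow error and
compact the conversation. summarize_messages_inplace() is real (tests/test_summarize.py
above calls it directly), but the wrapper below is a hedged sketch of the control flow,
not letta's exact code; step() here stands in for whatever method sends the message:

def step_with_overflow_handling(agent, user_message: str):
    """Retry once after summarizing if the context window overflows."""
    try:
        return agent.step(user_message)
    except Exception as e:
        # Match the specific substring embedded in the overflow error message.
        if "maximum context length" in str(e):
            agent.summarize_messages_inplace()  # evict/summarize older messages
            return agent.step(user_message)  # retry with the compacted history
        raise

The new test_auto_summarize test drives exactly this path end to end: it shrinks the
context window to 3,000 tokens so that a handful of user messages forces an overflow,
then asserts that the summarizer's "hidden from view" notice appears in the history.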
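
Finally, on the docstring and tool_manager.py changes: num_tokens_from_functions now
raises when a function schema lacks a description, which is why the test tools gained
summary lines and why create_tool backfills tool.description from the JSON schema.
letta appears to derive that schema description from the docstring's summary line; the
sketch below illustrates the idea (derive_description is a hypothetical helper, not
letta's schema generator):

import inspect
from typing import Optional


def derive_description(func) -> Optional[str]:
    """Use the first line of a function's docstring as its tool description."""
    doc = inspect.getdoc(func)  # dedents and strips surrounding blank lines
    return doc.splitlines()[0].strip() if doc else None


def print_tool(message: str):
    """
    Example tool that prints a message

    Args:
        message (str): The message to print.
    """
    print(message)


assert derive_description(print_tool) == "Example tool that prints a message"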