fix: context window overflow patch (#2053)

This commit is contained in:
Charles Packer
2024-11-18 15:21:11 -08:00
committed by GitHub
parent f57dc28552
commit a75dfd7907
6 changed files with 85 additions and 4 deletions

View File

@@ -1584,7 +1584,8 @@ class Agent(BaseAgent):
def count_tokens(self) -> int:
    """Count the tokens in the current context window.

    Delegates to ``get_context_window()``, which computes a full breakdown
    of context usage, and returns only the current total size in tokens.

    Returns:
        int: number of tokens currently occupying the context window.
    """
    # Diff residue fix: the old one-liner `return self.get_context_window()...`
    # was left in alongside the new two-line version, making the new lines
    # unreachable. Keep only the post-patch implementation.
    context_window_breakdown = self.get_context_window()
    return context_window_breakdown.context_window_size_current
def save_agent(agent: Agent, ms: MetadataStore):

View File

@@ -25,6 +25,7 @@ from letta.local_llm.constants import (
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
)
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import (
@@ -123,6 +124,14 @@ def create(
"""Return response to chat completion with backoff"""
from letta.utils import printd
# Count the tokens first, if there's an overflow exit early by throwing an error up the stack
# NOTE: we want to include a specific substring in the error message to trigger summarization
messages_oai_format = [m.to_openai_dict() for m in messages]
prompt_tokens = num_tokens_from_messages(messages=messages_oai_format, model=llm_config.model)
function_tokens = num_tokens_from_functions(functions=functions, model=llm_config.model) if functions else 0
if prompt_tokens + function_tokens > llm_config.context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens + function_tokens} > {llm_config.context_window} tokens)")
if not model_settings:
from letta.settings import model_settings

View File

@@ -94,7 +94,10 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
num_tokens = 0
for function in functions:
function_tokens = len(encoding.encode(function["name"]))
function_tokens += len(encoding.encode(function["description"]))
if function["description"]:
function_tokens += len(encoding.encode(function["description"]))
else:
raise ValueError(f"Function {function['name']} has no description, function: {function}")
if "parameters" in function:
parameters = function["parameters"]
@@ -229,7 +232,13 @@ def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int:
# num_tokens += len(encoding.encode(value["arguments"]))
else:
num_tokens += len(encoding.encode(value))
if value is None:
# raise ValueError(f"Message has null value: {key} with value: {value} - message={message}")
warnings.warn(f"Message has null value: {key} with value: {value} - message={message}")
else:
if not isinstance(value, str):
raise ValueError(f"Message has non-string value: {key} with value: {value} - message={message}")
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name

View File

@@ -70,6 +70,10 @@ class ToolManager:
pydantic_tool.organization_id = actor.organization_id
tool_data = pydantic_tool.model_dump()
tool = ToolModel(**tool_data)
# The description is most likely auto-generated via the json_schema,
# so copy it over into the top-level description field
if tool.description is None:
tool.description = tool.json_schema.get("description", None)
tool.create(session, actor=actor)
return tool.to_pydantic()

View File

@@ -1,8 +1,11 @@
import uuid
from typing import List
from letta import create_client
from letta.client.client import LocalClient
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from .utils import wipe_config
@@ -33,7 +36,7 @@ def create_test_agent():
agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
def test_summarize():
def test_summarize_messages_inplace():
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
global client
global agent_obj
@@ -73,3 +76,53 @@ def test_summarize():
agent_obj.summarize_messages_inplace()
print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}")
# response = client.run_command(agent_id=agent_obj.agent_state.id, command="summarize")
def test_auto_summarize():
    """Test that the summarizer triggers by itself"""
    client = create_client()
    client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
    client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))

    # Shrink the context window so overflow (and thus summarization) happens
    # after only a handful of turns.
    # default system prompt + funcs lead to ~2300 tokens, after one message it's at 2523 tokens
    tiny_llm_config = LLMConfig.default_config("gpt-4")
    tiny_llm_config.context_window = 3000

    agent_state = client.create_agent(
        name="small_context_agent",
        llm_config=tiny_llm_config,
    )

    try:

        def summarize_message_exists(messages: List[Message]) -> bool:
            # A summary message is identified by this marker substring in its text.
            return any(
                message.text and "have been hidden from view due to conversation memory constraints" in message.text
                for message in messages
            )

        max_attempts = 5
        sent = 0
        while True:
            # send a message
            response = client.user_message(
                agent_id=agent_state.id,
                message="What is the meaning of life?",
            )
            sent += 1

            # check if the summarize message is inside the messages
            assert isinstance(client, LocalClient), "Test only works with LocalClient"
            agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
            if summarize_message_exists(agent_obj._messages):
                # We found a summarize message
                print(f"Summarize message found after {sent} messages")
                break

            if sent > max_attempts:
                raise Exception(f"Summarize message not found after {sent} messages")
    finally:
        # Always clean up the agent, even when the test fails.
        client.delete_agent(agent_state.id)

View File

@@ -83,6 +83,8 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
def print_tool(message: str):
"""
Example tool that prints a message
Args:
message (str): The message to print.
@@ -110,6 +112,7 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
assert tool.id == tool.id, f"Expected {tool.id} to be {tool.id}"
# create agent with tool
assert tool.name is not None, "Expected tool name to be set"
agent_state = client.create_agent(tools=[tool.name])
# Send message without error
@@ -121,6 +124,8 @@ def test_create_agent_tool(client):
def core_memory_clear(self: "Agent"):
"""
Clear the core memory of the agent
Args:
agent (Agent): The agent to delete from memory.