fix: context window overflow patch (#2053)
This commit is contained in:
@@ -1584,7 +1584,8 @@ class Agent(BaseAgent):
|
||||
|
||||
def count_tokens(self) -> int:
    """Count the tokens in the current context window.

    Returns:
        int: the current total token usage of the agent's context window.

    NOTE: the pre-patch one-liner and the post-patch two-line form were both
    present here (diff residue), leaving an unreachable duplicate return;
    only the post-patch version is kept.
    """
    # get_context_window() computes a full breakdown; we only need the total.
    context_window_breakdown = self.get_context_window()
    return context_window_breakdown.context_window_size_current
|
||||
|
||||
|
||||
def save_agent(agent: Agent, ms: MetadataStore):
|
||||
|
||||
@@ -25,6 +25,7 @@ from letta.local_llm.constants import (
|
||||
INNER_THOUGHTS_KWARG,
|
||||
INNER_THOUGHTS_KWARG_DESCRIPTION,
|
||||
)
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
@@ -123,6 +124,14 @@ def create(
|
||||
"""Return response to chat completion with backoff"""
|
||||
from letta.utils import printd
|
||||
|
||||
# Count the tokens first, if there's an overflow exit early by throwing an error up the stack
|
||||
# NOTE: we want to include a specific substring in the error message to trigger summarization
|
||||
messages_oai_format = [m.to_openai_dict() for m in messages]
|
||||
prompt_tokens = num_tokens_from_messages(messages=messages_oai_format, model=llm_config.model)
|
||||
function_tokens = num_tokens_from_functions(functions=functions, model=llm_config.model) if functions else 0
|
||||
if prompt_tokens + function_tokens > llm_config.context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens + function_tokens} > {llm_config.context_window} tokens)")
|
||||
|
||||
if not model_settings:
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
@@ -94,7 +94,10 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
|
||||
num_tokens = 0
|
||||
for function in functions:
|
||||
function_tokens = len(encoding.encode(function["name"]))
|
||||
function_tokens += len(encoding.encode(function["description"]))
|
||||
if function["description"]:
|
||||
function_tokens += len(encoding.encode(function["description"]))
|
||||
else:
|
||||
raise ValueError(f"Function {function['name']} has no description, function: {function}")
|
||||
|
||||
if "parameters" in function:
|
||||
parameters = function["parameters"]
|
||||
@@ -229,7 +232,13 @@ def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int:
|
||||
# num_tokens += len(encoding.encode(value["arguments"]))
|
||||
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if value is None:
|
||||
# raise ValueError(f"Message has null value: {key} with value: {value} - message={message}")
|
||||
warnings.warn(f"Message has null value: {key} with value: {value} - message={message}")
|
||||
else:
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"Message has non-string value: {key} with value: {value} - message={message}")
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
@@ -70,6 +70,10 @@ class ToolManager:
|
||||
pydantic_tool.organization_id = actor.organization_id
|
||||
tool_data = pydantic_tool.model_dump()
|
||||
tool = ToolModel(**tool_data)
|
||||
# The description is most likely auto-generated via the json_schema,
|
||||
# so copy it over into the top-level description field
|
||||
if tool.description is None:
|
||||
tool.description = tool.json_schema.get("description", None)
|
||||
tool.create(session, actor=actor)
|
||||
|
||||
return tool.to_pydantic()
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
|
||||
from .utils import wipe_config
|
||||
|
||||
@@ -33,7 +36,7 @@ def create_test_agent():
|
||||
agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
|
||||
|
||||
|
||||
def test_summarize():
|
||||
def test_summarize_messages_inplace():
|
||||
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
|
||||
global client
|
||||
global agent_obj
|
||||
@@ -73,3 +76,53 @@ def test_summarize():
|
||||
agent_obj.summarize_messages_inplace()
|
||||
print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}")
|
||||
# response = client.run_command(agent_id=agent_obj.agent_state.id, command="summarize")
|
||||
|
||||
|
||||
def test_auto_summarize():
    """Test that the summarizer triggers by itself"""
    client = create_client()
    client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
    client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))

    # Shrink the context window so overflow (and therefore summarization)
    # happens after only a handful of messages.
    # default system prompt + funcs lead to ~2300 tokens, after one message it's at 2523 tokens
    SMALL_CONTEXT_WINDOW = 3000
    tiny_llm_config = LLMConfig.default_config("gpt-4")
    tiny_llm_config.context_window = SMALL_CONTEXT_WINDOW

    agent_state = client.create_agent(
        name="small_context_agent",
        llm_config=tiny_llm_config,
    )

    try:

        def summarize_message_exists(messages: List[Message]) -> bool:
            # The summary message embeds this marker once older messages are evicted.
            marker = "have been hidden from view due to conversation memory constraints"
            return any(m.text and marker in m.text for m in messages)

        MAX_ATTEMPTS = 5
        message_count = 0
        while True:
            # Send one user message and count it toward the attempt budget.
            response = client.user_message(
                agent_id=agent_state.id,
                message="What is the meaning of life?",
            )
            message_count += 1

            # Inspect the agent's in-memory message list for the summary marker;
            # this requires direct server access, hence the LocalClient guard.
            assert isinstance(client, LocalClient), "Test only works with LocalClient"
            agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
            if summarize_message_exists(agent_obj._messages):
                print(f"Summarize message found after {message_count} messages")
                break

            if message_count > MAX_ATTEMPTS:
                raise Exception(f"Summarize message not found after {message_count} messages")

    finally:
        # Always clean up the agent, even when the loop raised.
        client.delete_agent(agent_state.id)
|
||||
|
||||
@@ -83,6 +83,8 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
|
||||
|
||||
def print_tool(message: str):
|
||||
"""
|
||||
Example tool that prints a message
|
||||
|
||||
Args:
|
||||
message (str): The message to print.
|
||||
|
||||
@@ -110,6 +112,7 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
|
||||
assert tool.id == tool.id, f"Expected {tool.id} to be {tool.id}"
|
||||
|
||||
# create agent with tool
|
||||
assert tool.name is not None, "Expected tool name to be set"
|
||||
agent_state = client.create_agent(tools=[tool.name])
|
||||
|
||||
# Send message without error
|
||||
@@ -121,6 +124,8 @@ def test_create_agent_tool(client):
|
||||
|
||||
def core_memory_clear(self: "Agent"):
|
||||
"""
|
||||
Clear the core memory of the agent
|
||||
|
||||
Args:
|
||||
agent (Agent): The agent to delete from memory.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user