feat: expand summarizer providers
This commit is contained in:
@@ -67,7 +67,8 @@ class BaseAgent(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def pre_process_input_message(self, input_messages: List[MessageCreate]) -> Any:
|
||||
@staticmethod
|
||||
def pre_process_input_message(input_messages: List[MessageCreate]) -> Any:
|
||||
"""
|
||||
Pre-process function to run on the input_message.
|
||||
"""
|
||||
|
||||
@@ -1,27 +1,28 @@
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.prompts.gpt_system import get_system_text
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EphemeralSummaryAgent(BaseAgent):
|
||||
"""
|
||||
A stateless summarization agent (thin wrapper around OpenAI)
|
||||
|
||||
# TODO: Extend to more clients
|
||||
A stateless summarization agent that utilizes the caller's LLM client to summarize the conversation.
|
||||
TODO (cliandy): allow the summarizer to use another llm_config from the main agent maybe?
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -35,7 +36,7 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
):
|
||||
super().__init__(
|
||||
agent_id=agent_id,
|
||||
openai_client=AsyncOpenAI(),
|
||||
openai_client=None,
|
||||
message_manager=message_manager,
|
||||
agent_manager=agent_manager,
|
||||
actor=actor,
|
||||
@@ -65,17 +66,33 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
input_message = input_messages[0]
|
||||
input_message.content[0].text += f"\n\n--- Previous Summary ---\n{block.value}\n"
|
||||
|
||||
openai_messages = self.pre_process_input_message(input_messages=input_messages)
|
||||
request = self._build_openai_request(openai_messages)
|
||||
# Gets the LLMClient based on the calling agent's LLM config
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
# TODO: Extend to generic client
|
||||
chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
|
||||
summary = chat_completion.choices[0].message.content.strip()
|
||||
system_message_create = MessageCreate(
|
||||
role=MessageRole.system,
|
||||
content=[TextContent(text=get_system_text("summary_system_prompt"))],
|
||||
)
|
||||
messages = convert_message_creates_to_messages(
|
||||
message_creates=[system_message_create] + input_messages,
|
||||
agent_id=self.agent_id,
|
||||
timezone=agent_state.timezone,
|
||||
)
|
||||
|
||||
request_data = llm_client.build_request_data(messages, agent_state.llm_config, tools=[])
|
||||
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)
|
||||
summary = response.choices[0].message.content.strip()
|
||||
|
||||
await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=summary), actor=self.actor)
|
||||
|
||||
print(block)
|
||||
print(summary)
|
||||
logger.debug("block:", block)
|
||||
logger.debug("summary:", summary)
|
||||
|
||||
return [
|
||||
Message(
|
||||
@@ -84,22 +101,5 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
)
|
||||
]
|
||||
|
||||
def _build_openai_request(self, openai_messages: List[Dict]) -> ChatCompletionRequest:
|
||||
current_dir = Path(__file__).parent
|
||||
file_path = current_dir / "prompts" / "summary_system_prompt.txt"
|
||||
with open(file_path, "r") as file:
|
||||
system = file.read()
|
||||
|
||||
system_message = [{"role": "system", "content": system}]
|
||||
|
||||
openai_request = ChatCompletionRequest(
|
||||
model="gpt-4o",
|
||||
messages=system_message + openai_messages,
|
||||
user=self.actor.id,
|
||||
max_completion_tokens=4096,
|
||||
temperature=0.7,
|
||||
)
|
||||
return openai_request
|
||||
|
||||
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]:
|
||||
# Streaming is not implemented for the summarizer; only the non-streaming `step` is supported.
raise NotImplementedError("EphemeralSummaryAgent does not support streaming steps.")
|
||||
|
||||
@@ -58,11 +58,15 @@ from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
from letta.settings import model_settings
|
||||
from letta.settings import model_settings, summarizer_settings
|
||||
from letta.system import package_function_response
|
||||
from letta.types import JsonDict
|
||||
from letta.utils import log_telemetry, validate_function_response
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_SUMMARY_BLOCK_LABEL = "conversation_summary"
|
||||
|
||||
|
||||
class LettaAgent(BaseAgent):
|
||||
|
||||
@@ -77,11 +81,11 @@ class LettaAgent(BaseAgent):
|
||||
actor: User,
|
||||
step_manager: StepManager = NoopStepManager(),
|
||||
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
|
||||
summary_block_label: str = "conversation_summary",
|
||||
message_buffer_limit: int = 60, # TODO: Make this configurable
|
||||
message_buffer_min: int = 15, # TODO: Make this configurable
|
||||
enable_summarization: bool = True, # TODO: Make this configurable
|
||||
max_summarization_retries: int = 3, # TODO: Make this configurable
|
||||
summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL,
|
||||
message_buffer_limit: int = summarizer_settings.message_buffer_limit,
|
||||
message_buffer_min: int = summarizer_settings.message_buffer_min,
|
||||
enable_summarization: bool = summarizer_settings.enable_summarization,
|
||||
max_summarization_retries: int = summarizer_settings.max_summarization_retries,
|
||||
):
|
||||
super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)
|
||||
|
||||
@@ -117,7 +121,7 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
|
||||
self.summarizer = Summarizer(
|
||||
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
|
||||
mode=SummarizationMode(summarizer_settings.mode),
|
||||
summarizer_agent=self.summarization_agent,
|
||||
# TODO: Make this configurable
|
||||
message_buffer_limit=message_buffer_limit,
|
||||
|
||||
@@ -302,6 +302,7 @@ class OpenAIProvider(Provider):
|
||||
if self.base_url == "https://api.openai.com/v1":
|
||||
allowed_types = ["gpt-4", "o1", "o3", "o4"]
|
||||
# NOTE: o1-mini and o1-preview do not support tool calling
|
||||
# NOTE: o1-mini does not support system messages
|
||||
# NOTE: o1-pro is only available in Responses API
|
||||
disallowed_types = ["transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"]
|
||||
skip = True
|
||||
|
||||
@@ -90,7 +90,7 @@ async def add_feedback(
|
||||
|
||||
|
||||
@router.patch("/{step_id}/transaction/{transaction_id}", response_model=Step, operation_id="update_step_transaction_id")
|
||||
def update_step_transaction_id(
|
||||
async def update_step_transaction_id(
|
||||
step_id: str,
|
||||
transaction_id: str,
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -102,6 +102,6 @@ def update_step_transaction_id(
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id)
|
||||
return await server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
@@ -173,7 +173,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
|
||||
async def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
|
||||
"""Update the transaction ID for a step.
|
||||
|
||||
Args:
|
||||
@@ -187,15 +187,15 @@ class StepManager:
|
||||
Raises:
|
||||
NoResultFound: If the step does not exist
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
step = session.get(StepModel, step_id)
|
||||
async with db_registry.async_session() as session:
|
||||
step = await session.get(StepModel, step_id)
|
||||
if not step:
|
||||
raise NoResultFound(f"Step with id {step_id} does not exist")
|
||||
if step.organization_id != actor.organization_id:
|
||||
raise Exception("Unauthorized")
|
||||
|
||||
step.tid = transaction_id
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return step.to_pydantic()
|
||||
|
||||
def _verify_job_access(
|
||||
@@ -226,8 +226,8 @@ class StepManager:
|
||||
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
|
||||
return job
|
||||
|
||||
@staticmethod
|
||||
async def _verify_job_access_async(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
job_id: str,
|
||||
actor: PydanticUser,
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.templates.template_helper import render_template
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -123,30 +124,13 @@ class Summarizer:
|
||||
formatted_evicted_messages = [f"{i}. {msg}" for (i, msg) in enumerate(formatted_evicted_messages)]
|
||||
formatted_in_context_messages = [f"{i + offset}. {msg}" for (i, msg) in enumerate(formatted_in_context_messages)]
|
||||
|
||||
evicted_messages_str = "\n".join(formatted_evicted_messages)
|
||||
in_context_messages_str = "\n".join(formatted_in_context_messages)
|
||||
# Base prompt
|
||||
prompt_header = (
|
||||
f"You’re a memory-recall helper for an AI that can only keep the last {retain_count} messages. "
|
||||
"Scan the conversation history, focusing on messages about to drop out of that window, "
|
||||
"and write crisp notes that capture any important facts or insights about the conversation history so they aren’t lost."
|
||||
summary_request_text = render_template(
|
||||
"summary_request_text.j2",
|
||||
retain_count=retain_count,
|
||||
evicted_messages=formatted_evicted_messages,
|
||||
in_context_messages=formatted_in_context_messages,
|
||||
)
|
||||
|
||||
# Sections
|
||||
evicted_section = f"\n\n(Older) Evicted Messages:\n{evicted_messages_str}" if evicted_messages_str.strip() else ""
|
||||
in_context_section = ""
|
||||
|
||||
if retain_count > 0 and in_context_messages_str.strip():
|
||||
in_context_section = f"\n\n(Newer) In-Context Messages:\n{in_context_messages_str}"
|
||||
elif retain_count == 0:
|
||||
prompt_header = (
|
||||
"You’re a memory-recall helper for an AI that is about to forget all prior messages. "
|
||||
"Scan the conversation history and write crisp notes that capture any important facts or insights about the conversation history."
|
||||
)
|
||||
|
||||
# Compose final prompt
|
||||
summary_request_text = prompt_header + evicted_section + in_context_section
|
||||
|
||||
# Fire-and-forget the summarization task
|
||||
self.fire_and_forget(
|
||||
self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import AliasChoices, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from letta.local_llm.constants import DEFAULT_WRAPPER_NAME
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
|
||||
|
||||
class ToolSettings(BaseSettings):
|
||||
@@ -38,6 +39,13 @@ class ToolSettings(BaseSettings):
|
||||
class SummarizerSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="letta_summarizer_", extra="ignore")
|
||||
|
||||
mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER
|
||||
message_buffer_limit: int = 60
|
||||
message_buffer_min: int = 15
|
||||
enable_summarization: bool = True
|
||||
max_summarization_retries: int = 3
|
||||
|
||||
# TODO(cliandy): the below settings are tied to old summarization and should be deprecated or moved
|
||||
# Controls if we should evict all messages
|
||||
# TODO: Can refactor this into an enum if we have a bunch of different kinds of summarizers
|
||||
evict_all_messages: bool = False
|
||||
|
||||
19
letta/templates/summary_request_text.j2
Normal file
19
letta/templates/summary_request_text.j2
Normal file
@@ -0,0 +1,19 @@
|
||||
{% if retain_count == 0 %}
|
||||
You’re a memory-recall helper for an AI that is about to forget all prior messages. Scan the conversation history and write crisp notes that capture any important facts or insights about the conversation history.
|
||||
{% else %}
|
||||
You’re a memory-recall helper for an AI that can only keep the last {{ retain_count }} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost.
|
||||
{% endif %}
|
||||
|
||||
{% if evicted_messages %}
|
||||
(Older) Evicted Messages:
|
||||
{% for item in evicted_messages %}
|
||||
{{ item }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if retain_count > 0 and in_context_messages %}
|
||||
(Newer) In-Context Messages:
|
||||
{% for item in in_context_messages %}
|
||||
{{ item }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
@@ -1,8 +0,0 @@
|
||||
{
|
||||
"context_window": 128000,
|
||||
"model": "o1-mini",
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://api.openai.com/v1",
|
||||
"model_wrapper": null,
|
||||
"temperature": 1.0
|
||||
}
|
||||
@@ -103,7 +103,6 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
"openai-o1.json",
|
||||
"openai-o1-mini.json",
|
||||
"openai-o3.json",
|
||||
"openai-o4-mini.json",
|
||||
"azure-gpt-4o-mini.json",
|
||||
@@ -116,7 +115,7 @@ all_configs = [
|
||||
"gemini-2.5-flash-vertex.json",
|
||||
"gemini-2.5-pro-vertex.json",
|
||||
"together-qwen-2.5-72b-instruct.json",
|
||||
"ollama.json",
|
||||
# "ollama.json", # TODO (cliandy): enable this in ollama testing
|
||||
]
|
||||
|
||||
|
||||
@@ -1215,7 +1214,7 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM
|
||||
new_llm_config = llm_config.model_dump()
|
||||
new_llm_config["context_window"] = 3000
|
||||
pinned_context_window_llm_config = LLMConfig(**new_llm_config)
|
||||
|
||||
print("::LLM::", llm_config, new_llm_config)
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
temp_agent_state = client.agents.create(
|
||||
include_base_tools=False,
|
||||
|
||||
@@ -54,6 +54,7 @@ def actor(server, org_id):
|
||||
server.user_manager.delete_user_by_id(user.id)
|
||||
|
||||
|
||||
@pytest.mark.flaky(max_runs=3)
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_sleeptime_group_chat(server, actor):
|
||||
# 0. Refresh base tools
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
pythonpath = /letta
|
||||
testpaths = /tests
|
||||
asyncio_mode = auto
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
filterwarnings =
|
||||
ignore::pytest.PytestRemovedIn9Warning
|
||||
# suppresses the warnings we see with the event_loop fixture
|
||||
|
||||
@@ -443,6 +443,7 @@ def test_function_return_limit(disable_e2b_api_key, client: LettaSDKClient, agen
|
||||
assert "function output was truncated " in res
|
||||
|
||||
|
||||
@pytest.mark.flaky(max_runs=3)
|
||||
def test_function_always_error(client: LettaSDKClient, agent: AgentState):
|
||||
"""Test to see if function that errors works correctly"""
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from letta.services.summarizer.summarizer import Summarizer
|
||||
# Constants for test parameters
|
||||
MESSAGE_BUFFER_LIMIT = 10
|
||||
MESSAGE_BUFFER_MIN = 3
|
||||
PREVIOUS_SUMMARY = "Previous summary"
|
||||
SUMMARY_TEXT = "Summarized memory"
|
||||
|
||||
|
||||
@@ -22,6 +21,7 @@ SUMMARY_TEXT = "Summarized memory"
|
||||
def mock_summarizer_agent():
|
||||
agent = AsyncMock(spec=BaseAgent)
|
||||
agent.step.return_value = [Message(role=MessageRole.assistant, content=[TextContent(type="text", text=SUMMARY_TEXT)])]
|
||||
agent.update_message_transcript = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@@ -40,10 +40,9 @@ def messages():
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_buffer_summarization_no_trim_needed(mock_summarizer_agent, messages):
|
||||
summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=20)
|
||||
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:5], [], PREVIOUS_SUMMARY)
|
||||
updated_messages, updated = summarizer._static_buffer_summarization(messages[:5], [])
|
||||
|
||||
assert len(updated_messages) == 5
|
||||
assert summary == PREVIOUS_SUMMARY
|
||||
assert not updated
|
||||
|
||||
|
||||
@@ -55,11 +54,10 @@ async def test_static_buffer_summarization_trim_needed(mock_summarizer_agent, me
|
||||
message_buffer_limit=MESSAGE_BUFFER_LIMIT,
|
||||
message_buffer_min=MESSAGE_BUFFER_MIN,
|
||||
)
|
||||
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
|
||||
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
|
||||
|
||||
assert len(updated_messages) == MESSAGE_BUFFER_MIN # Should be trimmed down to min buffer size
|
||||
assert len(updated_messages) == MESSAGE_BUFFER_MIN
|
||||
assert updated
|
||||
assert SUMMARY_TEXT in summary
|
||||
mock_summarizer_agent.step.assert_called()
|
||||
|
||||
|
||||
@@ -75,21 +73,19 @@ async def test_static_buffer_summarization_trim_user_message(mock_summarizer_age
|
||||
# Modify messages to ensure a user message is available to trim at the correct index
|
||||
messages[5].role = MessageRole.user # Ensure a user message exists in trimming range
|
||||
|
||||
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
|
||||
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
|
||||
|
||||
assert len(updated_messages) == MESSAGE_BUFFER_MIN
|
||||
assert updated
|
||||
assert SUMMARY_TEXT in summary
|
||||
mock_summarizer_agent.step.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_buffer_summarization_no_trim_no_summarization(mock_summarizer_agent, messages):
|
||||
summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=15)
|
||||
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:8], [], PREVIOUS_SUMMARY)
|
||||
updated_messages, updated = summarizer._static_buffer_summarization(messages[:8], [])
|
||||
|
||||
assert len(updated_messages) == 8
|
||||
assert summary == PREVIOUS_SUMMARY
|
||||
assert not updated
|
||||
mock_summarizer_agent.step.assert_not_called()
|
||||
|
||||
@@ -106,11 +102,10 @@ async def test_static_buffer_summarization_json_parsing_failure(mock_summarizer_
|
||||
# Inject malformed JSON
|
||||
messages[2].content = [TextContent(type="text", text="malformed json")]
|
||||
|
||||
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
|
||||
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
|
||||
|
||||
assert len(updated_messages) == MESSAGE_BUFFER_MIN
|
||||
assert updated
|
||||
assert SUMMARY_TEXT in summary
|
||||
mock_summarizer_agent.step.assert_called()
|
||||
|
||||
|
||||
@@ -127,11 +122,10 @@ async def test_static_buffer_summarization_all_user_messages_trimmed(mock_summar
|
||||
for i in range(12):
|
||||
messages[i].role = MessageRole.user
|
||||
|
||||
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
|
||||
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
|
||||
|
||||
assert len(updated_messages) == MESSAGE_BUFFER_MIN
|
||||
assert len(updated_messages) == MESSAGE_BUFFER_MIN + 1
|
||||
assert updated
|
||||
assert SUMMARY_TEXT in summary
|
||||
mock_summarizer_agent.step.assert_called()
|
||||
|
||||
|
||||
@@ -148,10 +142,9 @@ async def test_static_buffer_summarization_no_assistant_messages_trimmed(mock_su
|
||||
for i in range(12):
|
||||
messages[i].role = MessageRole.assistant
|
||||
|
||||
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
|
||||
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
|
||||
|
||||
# Yeah, so this actually has to end on 1, because we basically can find no user, so we trim everything
|
||||
assert len(updated_messages) == 1
|
||||
assert updated
|
||||
assert SUMMARY_TEXT in summary
|
||||
mock_summarizer_agent.step.assert_called()
|
||||
|
||||
Reference in New Issue
Block a user