feat: Finish async memory rewriting agent for voice (#1161)

This commit is contained in:
Matthew Zhou
2025-03-03 13:58:06 -08:00
committed by GitHub
parent 19e65bb2c0
commit 353af9aefe
13 changed files with 471 additions and 51 deletions

0
letta/agents/__init__.py Normal file
View File

View File

@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List
import openai
from letta.schemas.letta_message import UserMessage
from letta.schemas.message import Message
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.message_manager import MessageManager
class BaseAgent(ABC):
    """
    Common contract for AI agents.

    Holds the shared dependencies (OpenAI client, message/agent managers,
    acting user) and declares the step/step_stream entry points that every
    concrete agent must implement.
    """

    def __init__(
        self,
        agent_id: str,
        openai_client: openai.AsyncClient,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        actor: User,
    ):
        # Identity of the agent this instance operates on.
        self.agent_id = agent_id
        # Acting user; passed through to all manager calls for authorization.
        self.actor = actor
        # Shared service dependencies.
        self.openai_client = openai_client
        self.message_manager = message_manager
        self.agent_manager = agent_manager

    @abstractmethod
    async def step(self, input_message: UserMessage) -> List[Message]:
        """Run one full (non-streaming) execution loop for the agent."""
        raise NotImplementedError

    @abstractmethod
    async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
        """Run one execution loop, yielding messages as SSE events."""
        raise NotImplementedError

    def pre_process_input_message(self, input_message: UserMessage) -> Any:
        """Normalize the incoming message into a plain dict payload."""
        return input_message.model_dump()

View File

@@ -0,0 +1,72 @@
from typing import AsyncGenerator, Dict, List
import openai
from letta.agents.base_agent import BaseAgent
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import TextContent, UserMessage
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.message_manager import MessageManager
class EphemeralAgent(BaseAgent):
    """
    A stateless agent (thin wrapper around OpenAI).

    Used for one-off completions (e.g. summarization) where nothing needs to
    be persisted to the database.

    # TODO: Extend to more clients
    """

    def __init__(
        self,
        agent_id: str,
        openai_client: openai.AsyncClient,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        actor: User,
    ):
        super().__init__(
            agent_id=agent_id,
            openai_client=openai_client,
            message_manager=message_manager,
            agent_manager=agent_manager,
            actor=actor,
        )

    async def step(self, input_message: UserMessage) -> List[Message]:
        """
        Take a user's input text and return a single assistant Message with
        OpenAI's response.

        Note: the returned Message objects are ephemeral — they are not
        persisted anywhere by this agent.
        """
        agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
        input_message = self.pre_process_input_message(input_message=input_message)
        request = self._build_openai_request([input_message], agent_state)

        chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
        # The SDK types `message.content` as Optional; guard against None so a
        # refusal/empty completion yields empty text instead of AttributeError.
        content = chat_completion.choices[0].message.content or ""

        return [
            Message(
                role=MessageRole.assistant,
                content=[TextContent(text=content.strip())],
            )
        ]

    def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest:
        """Assemble the ChatCompletionRequest from the agent's LLM config."""
        openai_request = ChatCompletionRequest(
            model=agent_state.llm_config.model,
            messages=openai_messages,
            user=self.actor.id,
            max_completion_tokens=agent_state.llm_config.max_tokens,
            temperature=agent_state.llm_config.temperature,
        )
        return openai_request

    async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
        """
        Streaming is not supported by this agent.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("EphemeralAgent does not support async step.")

View File

@@ -1,10 +1,11 @@
import json
import uuid
from typing import Any, AsyncGenerator, Dict, List
from typing import Any, AsyncGenerator, Dict, List, Tuple
import openai
from starlette.concurrency import run_in_threadpool
from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_agent import EphemeralAgent
from letta.constants import NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import (
@@ -17,6 +18,7 @@ from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIC
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.block import BlockUpdate
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.openai.chat_completion_request import (
AssistantMessage,
@@ -28,7 +30,6 @@ from letta.schemas.openai.chat_completion_request import (
UserMessage,
)
from letta.schemas.user import User
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.utils import (
convert_letta_messages_to_openai,
create_assistant_messages_from_openai_response,
@@ -36,14 +37,17 @@ from letta.server.rest_api.utils import (
create_user_message,
)
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.utils import united_diff
logger = get_logger(__name__)
class LowLatencyAgent:
class LowLatencyAgent(BaseAgent):
"""
A function-calling loop for streaming OpenAI responses with tool execution.
This agent:
@@ -58,32 +62,53 @@ class LowLatencyAgent:
openai_client: openai.AsyncClient,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
actor: User,
summarization_mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER,
message_buffer_limit: int = 10,
message_buffer_min: int = 4,
):
self.agent_id = agent_id
self.openai_client = openai_client
super().__init__(
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
)
# DB access related fields
self.message_manager = message_manager
self.agent_manager = agent_manager
self.actor = actor
# TODO: Make this more general, factorable
# Summarizer settings
self.block_manager = block_manager
# TODO: This is not guaranteed to exist!
self.summary_block_label = "human"
self.summarizer = Summarizer(
mode=summarization_mode,
summarizer_agent=EphemeralAgent(
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
),
message_buffer_limit=message_buffer_limit,
message_buffer_min=message_buffer_min,
)
self.message_buffer_limit = message_buffer_limit
self.message_buffer_min = message_buffer_min
# Internal conversation state
self.optimistic_json_parser = OptimisticJSONParser(strict=True)
self.current_parsed_json_result: Dict[str, Any] = {}
async def step(self, input_message: UserMessage) -> List[Message]:
raise NotImplementedError("LowLatencyAgent does not have a synchronous step implemented currently.")
async def step(self, input_message: Dict[str, str]) -> AsyncGenerator[str, None]:
async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
"""
Async generator that yields partial tokens as SSE events, handles tool calls,
and streams error messages if OpenAI API failures occur.
"""
input_message = self.pre_process_input_message(input_message=input_message)
agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
letta_message_db_queue = [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)]
in_memory_message_history = [input_message]
while True:
# Build context and request
openai_messages = self._build_context_window(in_memory_message_history, agent_state)
# Constantly pull down and integrate memory blocks
in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state)
# Convert Letta messages to OpenAI messages
openai_messages = convert_letta_messages_to_openai(in_context_messages)
openai_messages.extend(in_memory_message_history)
request = self._build_openai_request(openai_messages, agent_state)
# Execute the request
@@ -94,24 +119,19 @@ class LowLatencyAgent:
yield sse
# Process the AI response (buffered messages, tool execution, etc.)
continue_execution = await self.handle_ai_response(
continue_execution = await self._handle_ai_response(
streaming_interface, agent_state, in_memory_message_history, letta_message_db_queue
)
if not continue_execution:
break
# Persist messages to the database asynchronously
await run_in_threadpool(
self.agent_manager.append_to_in_context_messages,
letta_message_db_queue,
agent_id=agent_state.id,
actor=self.actor,
)
# Rebuild context window
await self._rebuild_context_window(in_context_messages, letta_message_db_queue, agent_state)
yield "data: [DONE]\n\n"
async def handle_ai_response(
async def _handle_ai_response(
self,
streaming_interface: OpenAIChatCompletionsStreamingInterface,
agent_state: AgentState,
@@ -194,15 +214,24 @@ class LowLatencyAgent:
# Exit the loop if finish_reason_stop or no tool call occurred
return not streaming_interface.finish_reason_stop
def _build_context_window(self, in_memory_message_history: List[Dict[str, Any]], agent_state: AgentState) -> List[Dict]:
# Build in_context_messages
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state)
async def _rebuild_context_window(
self, in_context_messages: List[Message], letta_message_db_queue: List[Message], agent_state: AgentState
) -> None:
new_letta_messages = self.message_manager.create_many_messages(letta_message_db_queue, actor=self.actor)
# Convert Letta messages to OpenAI messages
openai_messages = convert_letta_messages_to_openai(in_context_messages)
openai_messages.extend(in_memory_message_history)
return openai_messages
# TODO: Make this more general and configurable, less brittle
target_block = next(b for b in agent_state.memory.blocks if b.label == self.summary_block_label)
previous_summary = self.block_manager.get_block_by_id(block_id=target_block.id, actor=self.actor).value
new_in_context_messages, summary_str, updated = await self.summarizer.summarize(
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, previous_summary=previous_summary
)
if updated:
self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=summary_str), actor=self.actor)
self.agent_manager.set_in_context_messages(
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
)
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
@@ -264,7 +293,7 @@ class LowLatencyAgent:
for t in tools
]
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> (str, bool):
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
"""
Executes a tool and returns (result, success_flag).
"""

View File

@@ -54,9 +54,9 @@ DEVELOPMENT_LOGGING = {
"propagate": True, # Let logs bubble up to root
},
"uvicorn": {
"level": "CRITICAL",
"level": "DEBUG",
"handlers": ["console"],
"propagate": False,
"propagate": True,
},
},
}

View File

@@ -6,8 +6,9 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from openai.types.chat.completion_create_params import CompletionCreateParams
from letta.agents.low_latency_agent import LowLatencyAgent
from letta.log import get_logger
from letta.low_latency_agent import LowLatencyAgent
from letta.schemas.openai.chat_completions import UserMessage
from letta.server.rest_api.utils import get_letta_server, get_messages_from_completion_request
from letta.settings import model_settings
@@ -44,12 +45,8 @@ async def create_voice_chat_completions(
if agent_id is None:
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
# agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
# if agent_state.llm_config.model_endpoint_type != "openai":
# raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.")
# Also parse the user's new input
input_message = get_messages_from_completion_request(completion_request)[-1]
input_message = UserMessage(**get_messages_from_completion_request(completion_request)[-1])
# Create OpenAI async client
client = openai.AsyncClient(
@@ -72,8 +69,11 @@ async def create_voice_chat_completions(
openai_client=client,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
actor=actor,
message_buffer_limit=10,
message_buffer_min=4,
)
# Return the streaming generator
return StreamingResponse(agent.step(input_message=input_message), media_type="text/event-stream")
return StreamingResponse(agent.step_stream(input_message=input_message), media_type="text/event-stream")

View File

@@ -535,40 +535,40 @@ class AgentManager:
# TODO: This seems kind of silly, why not just update the message?
message = self.message_manager.create_message(message, actor=actor)
message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system)
return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
else:
return agent_state
@enforce_types
def _set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
@enforce_types
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
return self._set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
@enforce_types
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
# TODO: How do we know this?
new_messages = [message_ids[0]] # 0 is system message
return self._set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
@enforce_types
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
new_messages = self.message_manager.create_many_messages(messages, actor=actor)
message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:]
return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
@enforce_types
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
messages = self.message_manager.create_many_messages(messages, actor=actor)
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
message_ids += [m.id for m in messages]
return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
@enforce_types
def reset_messages(self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False) -> PydanticAgentState:

View File

View File

@@ -0,0 +1,9 @@
from enum import Enum
class SummarizationMode(str, Enum):
    """Strategies available for trimming/summarizing conversation history."""

    # Trim the context down to a fixed-size message buffer and summarize
    # whatever falls out of it.
    STATIC_MESSAGE_BUFFER = "static_message_buffer_mode"

View File

@@ -0,0 +1,102 @@
import json
from json import JSONDecodeError
from typing import List, Tuple
from letta.agents.base_agent import BaseAgent
from letta.schemas.enums import MessageRole
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import UserMessage
from letta.services.summarizer.enums import SummarizationMode
class Summarizer:
    """
    Handles summarization or trimming of conversation messages based on
    the specified SummarizationMode. For now, we demonstrate a simple
    static buffer approach but leave room for more advanced strategies.
    """

    def __init__(self, mode: SummarizationMode, summarizer_agent: BaseAgent, message_buffer_limit: int = 10, message_buffer_min: int = 3):
        """
        Args:
            mode: Strategy used for trimming/summarizing the conversation.
            summarizer_agent: Agent invoked to rewrite the memory summary.
            message_buffer_limit: Max in-context messages before trimming kicks in.
            message_buffer_min: Target number of messages kept after trimming.
        """
        self.mode = mode

        # TODO: Need to do validation on these (e.g. limit > min >= 1)
        self.message_buffer_limit = message_buffer_limit
        self.message_buffer_min = message_buffer_min
        self.summarizer_agent = summarizer_agent

        # TODO: Move this to config
        self.summary_prefix = "Out of context message summarization:\n"

    async def summarize(
        self, in_context_messages: List[Message], new_letta_messages: List[Message], previous_summary: str
    ) -> Tuple[List[Message], str, bool]:
        """
        Summarizes or trims in_context_messages according to the chosen mode,
        and returns the updated messages plus any optional "summary message".

        Args:
            in_context_messages: The existing messages in the conversation's context.
            new_letta_messages: The newly added Letta messages (just appended).
            previous_summary: The previous summary string.

        Returns:
            (updated_messages, summary_str, updated)
            updated_messages: The new context after trimming/summary
            summary_str: The (possibly unchanged) summary text
            updated: Whether a new summary was generated
        """
        if self.mode == SummarizationMode.STATIC_MESSAGE_BUFFER:
            return await self._static_buffer_summarization(in_context_messages, new_letta_messages, previous_summary)
        else:
            # Fallback or future logic
            return in_context_messages, "", False

    async def _static_buffer_summarization(
        self, in_context_messages: List[Message], new_letta_messages: List[Message], previous_summary: str
    ) -> Tuple[List[Message], str, bool]:
        """
        Trim the combined context down to `message_buffer_min` messages once it
        exceeds `message_buffer_limit`, summarizing the evicted messages into
        an updated memory string (prefixed with `summary_prefix`).
        """
        # Strip the display prefix so the prompt sees only the raw summary body
        # (the prefix is re-prepended to the new summary below).
        # BUGFIX: this previously sliced `previous_summary[: len(self.summary_prefix)]`,
        # which truncated the summary to the length of the prefix and discarded
        # the actual notes.
        if previous_summary.startswith(self.summary_prefix):
            previous_summary = previous_summary[len(self.summary_prefix) :]

        all_in_context_messages = in_context_messages + new_letta_messages

        # Only summarize if we exceed `message_buffer_limit`
        if len(all_in_context_messages) <= self.message_buffer_limit:
            return all_in_context_messages, previous_summary, False

        # Aim to trim down to `message_buffer_min`
        target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1

        # Move the trim index forward until it's at a `MessageRole.user`, so the
        # retained window always starts on a user turn.
        while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user:
            target_trim_index += 1

        # TODO: Assuming system message is always at index 0
        updated_in_context_messages = [all_in_context_messages[0]] + all_in_context_messages[target_trim_index:]
        out_of_context_messages = all_in_context_messages[:target_trim_index]

        # Flatten the evicted messages into "role: message" lines for the prompt.
        formatted_messages = []
        for m in out_of_context_messages:
            if m.content:
                try:
                    message = json.loads(m.content[0].text).get("message")
                except JSONDecodeError:
                    # Skip messages whose payload isn't the expected JSON envelope.
                    continue
                if message:
                    formatted_messages.append(f"{m.role.value}: {message}")

        # If we didn't trim any messages, return as-is
        if not formatted_messages:
            return all_in_context_messages, previous_summary, False

        # Generate summarization request
        summary_request_text = (
            "These are messages that are soon to be removed from the context window:\n"
            f"{formatted_messages}\n\n"
            "This is the current memory:\n"
            f"{previous_summary}\n\n"
            "Your task is to integrate any relevant updates from the messages into the memory."
            "It should be in note-taking format in natural English. You are to return the new, updated memory only."
        )
        messages = await self.summarizer_agent.step(UserMessage(content=summary_request_text))
        # NOTE(review): assumes Message exposes a `.text` accessor — verify against the Message schema.
        current_summary = "\n".join([m.text for m in messages])
        current_summary = f"{self.summary_prefix}{current_summary}"
        return updated_in_context_messages, current_summary, True

View File

@@ -153,7 +153,7 @@ def _assert_valid_chunk(chunk, idx, chunks):
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["What's the weather in SF?"])
@pytest.mark.parametrize("message", ["How are you?"])
@pytest.mark.parametrize("endpoint", ["v1/voice"])
async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint):
"""Tests chat completion streaming using the Async OpenAI client."""

View File

@@ -175,7 +175,7 @@ def test_many_messages_performance(client, num_messages):
message_manager.create_many_messages(all_messages, actor=actor)
log_event("Inserted messages into the database")
agent_manager._set_in_context_messages(
agent_manager.set_in_context_messages(
agent_id=agent_state.id, message_ids=agent_state.message_ids + [m.id for m in all_messages], actor=client.user
)
log_event("Updated agent context with messages")

View File

@@ -0,0 +1,157 @@
import json
from datetime import datetime
from unittest.mock import AsyncMock
import pytest
from letta.agents.base_agent import BaseAgent
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import TextContent
from letta.schemas.message import Message
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
# Constants for test parameters
MESSAGE_BUFFER_LIMIT = 10
MESSAGE_BUFFER_MIN = 3
PREVIOUS_SUMMARY = "Previous summary"
SUMMARY_TEXT = "Summarized memory"


@pytest.fixture
def mock_summarizer_agent():
    """Async-mocked BaseAgent whose step() returns a canned summary message."""
    agent = AsyncMock(spec=BaseAgent)
    summary_message = Message(
        role=MessageRole.assistant,
        content=[TextContent(type="text", text=SUMMARY_TEXT)],
    )
    agent.step.return_value = [summary_message]
    return agent


@pytest.fixture
def messages():
    """15 alternating user/assistant messages with JSON-encoded bodies."""
    # NOTE: datetime.utcnow() is deprecated in 3.12+; consider
    # datetime.now(timezone.utc) if the Message schema accepts aware datetimes.
    result = []
    for i in range(15):
        role = MessageRole.user if i % 2 == 0 else MessageRole.assistant
        body = json.dumps({"message": f"Test message {i}"})
        result.append(
            Message(
                role=role,
                content=[TextContent(type="text", text=body)],
                created_at=datetime.utcnow(),
            )
        )
    return result
@pytest.mark.asyncio
async def test_static_buffer_summarization_no_trim_needed(mock_summarizer_agent, messages):
    """Below the buffer limit nothing is trimmed and the summary is untouched."""
    summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=20)

    updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:5], [], PREVIOUS_SUMMARY)

    assert not updated
    assert summary == PREVIOUS_SUMMARY
    assert len(updated_messages) == 5


@pytest.mark.asyncio
async def test_static_buffer_summarization_trim_needed(mock_summarizer_agent, messages):
    """Exceeding the limit trims down to the min buffer and produces a summary."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )

    updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    assert updated
    assert len(updated_messages) == MESSAGE_BUFFER_MIN  # trimmed down to min buffer size
    assert SUMMARY_TEXT in summary
    mock_summarizer_agent.step.assert_called()
@pytest.mark.asyncio
async def test_static_buffer_summarization_trim_user_message(mock_summarizer_agent, messages):
    """Trimming still works when the trim index must advance to a user message."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )
    # Guarantee a user message exists within the trimming range.
    messages[5].role = MessageRole.user

    updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    assert updated
    assert len(updated_messages) == MESSAGE_BUFFER_MIN
    assert SUMMARY_TEXT in summary
    mock_summarizer_agent.step.assert_called()


@pytest.mark.asyncio
async def test_static_buffer_summarization_no_trim_no_summarization(mock_summarizer_agent, messages):
    """A message count at or below the limit leaves everything untouched."""
    summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=15)

    updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:8], [], PREVIOUS_SUMMARY)

    assert not updated
    assert len(updated_messages) == 8
    assert summary == PREVIOUS_SUMMARY
    mock_summarizer_agent.step.assert_not_called()
@pytest.mark.asyncio
async def test_static_buffer_summarization_json_parsing_failure(mock_summarizer_agent, messages):
    """Messages whose payload is not valid JSON are skipped, not fatal."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )
    # Inject malformed JSON into a message that will be trimmed.
    messages[2].content = [TextContent(type="text", text="malformed json")]

    updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    assert updated
    assert len(updated_messages) == MESSAGE_BUFFER_MIN
    assert SUMMARY_TEXT in summary
    mock_summarizer_agent.step.assert_called()


@pytest.mark.asyncio
async def test_static_buffer_summarization_all_user_messages_trimmed(mock_summarizer_agent, messages):
    """An all-user-message history still trims cleanly to the min buffer."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )
    # Every message in the trimming range is a user message.
    for i in range(12):
        messages[i].role = MessageRole.user

    updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    assert updated
    assert len(updated_messages) == MESSAGE_BUFFER_MIN
    assert SUMMARY_TEXT in summary
    mock_summarizer_agent.step.assert_called()
@pytest.mark.asyncio
async def test_static_buffer_summarization_no_assistant_messages_trimmed(mock_summarizer_agent, messages):
    """With no user message to anchor the trim, only the system message survives."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )
    # No user messages anywhere in the trimming range.
    for i in range(12):
        messages[i].role = MessageRole.assistant

    updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    # The trim index runs off the end looking for a user message, so everything
    # after index 0 (the assumed system message) is removed.
    assert len(updated_messages) == 1
    assert updated
    assert SUMMARY_TEXT in summary
    mock_summarizer_agent.step.assert_called()