diff --git a/letta/agents/__init__.py b/letta/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py new file mode 100644 index 00000000..a9638eb6 --- /dev/null +++ b/letta/agents/base_agent.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Any, AsyncGenerator, List + +import openai + +from letta.schemas.letta_message import UserMessage +from letta.schemas.message import Message +from letta.schemas.user import User +from letta.services.agent_manager import AgentManager +from letta.services.message_manager import MessageManager + + +class BaseAgent(ABC): + """ + Abstract base class for AI agents, handling message management, tool execution, + and context tracking. + """ + + def __init__( + self, + agent_id: str, + openai_client: openai.AsyncClient, + message_manager: MessageManager, + agent_manager: AgentManager, + actor: User, + ): + self.agent_id = agent_id + self.openai_client = openai_client + self.message_manager = message_manager + self.agent_manager = agent_manager + self.actor = actor + + @abstractmethod + async def step(self, input_message: UserMessage) -> List[Message]: + """ + Main execution loop for the agent. + """ + raise NotImplementedError + + @abstractmethod + async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]: + """ + Main async execution loop for the agent. Implementations must yield messages as SSE events. + """ + raise NotImplementedError + + def pre_process_input_message(self, input_message: UserMessage) -> Any: + """ + Pre-process function to run on the input_message. 
+ """ + return input_message.model_dump() diff --git a/letta/agents/ephemeral_agent.py b/letta/agents/ephemeral_agent.py new file mode 100644 index 00000000..e12d78b1 --- /dev/null +++ b/letta/agents/ephemeral_agent.py @@ -0,0 +1,72 @@ +from typing import AsyncGenerator, Dict, List + +import openai + +from letta.agents.base_agent import BaseAgent +from letta.schemas.agent import AgentState +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import TextContent, UserMessage +from letta.schemas.message import Message +from letta.schemas.openai.chat_completion_request import ChatCompletionRequest +from letta.schemas.user import User +from letta.services.agent_manager import AgentManager +from letta.services.message_manager import MessageManager + + +class EphemeralAgent(BaseAgent): + """ + A stateless agent (thin wrapper around OpenAI) + + # TODO: Extend to more clients + """ + + def __init__( + self, + agent_id: str, + openai_client: openai.AsyncClient, + message_manager: MessageManager, + agent_manager: AgentManager, + actor: User, + ): + super().__init__( + agent_id=agent_id, + openai_client=openai_client, + message_manager=message_manager, + agent_manager=agent_manager, + actor=actor, + ) + + async def step(self, input_message: UserMessage) -> List[Message]: + """ + Synchronous method that takes a user's input text and returns a summary from OpenAI. + Returns a list of ephemeral Message objects containing both the user text and the assistant summary. 
+ """ + agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor) + + input_message = self.pre_process_input_message(input_message=input_message) + request = self._build_openai_request([input_message], agent_state) + + chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True)) + + return [ + Message( + role=MessageRole.assistant, + content=[TextContent(text=chat_completion.choices[0].message.content.strip())], + ) + ] + + def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest: + openai_request = ChatCompletionRequest( + model=agent_state.llm_config.model, + messages=openai_messages, + user=self.actor.id, + max_completion_tokens=agent_state.llm_config.max_tokens, + temperature=agent_state.llm_config.temperature, + ) + return openai_request + + async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]: + """ + This agent is synchronous-only. If called in an async context, raise an error. 
+ """ + raise NotImplementedError("EphemeralAgent does not support async step.") diff --git a/letta/low_latency_agent.py b/letta/agents/low_latency_agent.py similarity index 73% rename from letta/low_latency_agent.py rename to letta/agents/low_latency_agent.py index 4b7e5c82..d5d96f23 100644 --- a/letta/low_latency_agent.py +++ b/letta/agents/low_latency_agent.py @@ -1,10 +1,11 @@ import json import uuid -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator, Dict, List, Tuple import openai -from starlette.concurrency import run_in_threadpool +from letta.agents.base_agent import BaseAgent +from letta.agents.ephemeral_agent import EphemeralAgent from letta.constants import NON_USER_MSG_PREFIX from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.tool_execution_helper import ( @@ -17,6 +18,7 @@ from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIC from letta.log import get_logger from letta.orm.enums import ToolType from letta.schemas.agent import AgentState +from letta.schemas.block import BlockUpdate from letta.schemas.message import Message, MessageUpdate from letta.schemas.openai.chat_completion_request import ( AssistantMessage, @@ -28,7 +30,6 @@ from letta.schemas.openai.chat_completion_request import ( UserMessage, ) from letta.schemas.user import User -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser from letta.server.rest_api.utils import ( convert_letta_messages_to_openai, create_assistant_messages_from_openai_response, @@ -36,14 +37,17 @@ from letta.server.rest_api.utils import ( create_user_message, ) from letta.services.agent_manager import AgentManager +from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import compile_system_message from letta.services.message_manager import MessageManager +from letta.services.summarizer.enums import SummarizationMode +from 
letta.services.summarizer.summarizer import Summarizer from letta.utils import united_diff logger = get_logger(__name__) -class LowLatencyAgent: +class LowLatencyAgent(BaseAgent): """ A function-calling loop for streaming OpenAI responses with tool execution. This agent: @@ -58,32 +62,53 @@ class LowLatencyAgent: openai_client: openai.AsyncClient, message_manager: MessageManager, agent_manager: AgentManager, + block_manager: BlockManager, actor: User, + summarization_mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER, + message_buffer_limit: int = 10, + message_buffer_min: int = 4, ): - self.agent_id = agent_id - self.openai_client = openai_client + super().__init__( + agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor + ) - # DB access related fields - self.message_manager = message_manager - self.agent_manager = agent_manager - self.actor = actor + # TODO: Make this more general, factorable + # Summarizer settings + self.block_manager = block_manager + # TODO: This is not guaranteed to exist! 
+ self.summary_block_label = "human" + self.summarizer = Summarizer( + mode=summarization_mode, + summarizer_agent=EphemeralAgent( + agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor + ), + message_buffer_limit=message_buffer_limit, + message_buffer_min=message_buffer_min, + ) + self.message_buffer_limit = message_buffer_limit + self.message_buffer_min = message_buffer_min - # Internal conversation state - self.optimistic_json_parser = OptimisticJSONParser(strict=True) - self.current_parsed_json_result: Dict[str, Any] = {} + async def step(self, input_message: UserMessage) -> List[Message]: + raise NotImplementedError("LowLatencyAgent does not have a synchronous step implemented currently.") - async def step(self, input_message: Dict[str, str]) -> AsyncGenerator[str, None]: + async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]: """ Async generator that yields partial tokens as SSE events, handles tool calls, and streams error messages if OpenAI API failures occur. 
""" + input_message = self.pre_process_input_message(input_message=input_message) agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor) + in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor) letta_message_db_queue = [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)] in_memory_message_history = [input_message] while True: - # Build context and request - openai_messages = self._build_context_window(in_memory_message_history, agent_state) + # Constantly pull down and integrate memory blocks + in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state) + + # Convert Letta messages to OpenAI messages + openai_messages = convert_letta_messages_to_openai(in_context_messages) + openai_messages.extend(in_memory_message_history) request = self._build_openai_request(openai_messages, agent_state) # Execute the request @@ -94,24 +119,19 @@ class LowLatencyAgent: yield sse # Process the AI response (buffered messages, tool execution, etc.) 
- continue_execution = await self.handle_ai_response( + continue_execution = await self._handle_ai_response( streaming_interface, agent_state, in_memory_message_history, letta_message_db_queue ) if not continue_execution: break - # Persist messages to the database asynchronously - await run_in_threadpool( - self.agent_manager.append_to_in_context_messages, - letta_message_db_queue, - agent_id=agent_state.id, - actor=self.actor, - ) + # Rebuild context window + await self._rebuild_context_window(in_context_messages, letta_message_db_queue, agent_state) yield "data: [DONE]\n\n" - async def handle_ai_response( + async def _handle_ai_response( self, streaming_interface: OpenAIChatCompletionsStreamingInterface, agent_state: AgentState, @@ -194,15 +214,24 @@ class LowLatencyAgent: # Exit the loop if finish_reason_stop or no tool call occurred return not streaming_interface.finish_reason_stop - def _build_context_window(self, in_memory_message_history: List[Dict[str, Any]], agent_state: AgentState) -> List[Dict]: - # Build in_context_messages - in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor) - in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state) + async def _rebuild_context_window( + self, in_context_messages: List[Message], letta_message_db_queue: List[Message], agent_state: AgentState + ) -> None: + new_letta_messages = self.message_manager.create_many_messages(letta_message_db_queue, actor=self.actor) - # Convert Letta messages to OpenAI messages - openai_messages = convert_letta_messages_to_openai(in_context_messages) - openai_messages.extend(in_memory_message_history) - return openai_messages + # TODO: Make this more general and configurable, less brittle + target_block = next(b for b in agent_state.memory.blocks if b.label == self.summary_block_label) + previous_summary = self.block_manager.get_block_by_id(block_id=target_block.id, 
actor=self.actor).value + new_in_context_messages, summary_str, updated = await self.summarizer.summarize( + in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, previous_summary=previous_summary + ) + + if updated: + self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=summary_str), actor=self.actor) + + self.agent_manager.set_in_context_messages( + agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor + ) def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]: # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this @@ -264,7 +293,7 @@ class LowLatencyAgent: for t in tools ] - async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> (str, bool): + async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]: """ Executes a tool and returns (result, success_flag). 
""" diff --git a/letta/log.py b/letta/log.py index fbac3830..0d4ad8e1 100644 --- a/letta/log.py +++ b/letta/log.py @@ -54,9 +54,9 @@ DEVELOPMENT_LOGGING = { "propagate": True, # Let logs bubble up to root }, "uvicorn": { - "level": "CRITICAL", + "level": "DEBUG", "handlers": ["console"], - "propagate": False, + "propagate": True, }, }, } diff --git a/letta/server/rest_api/routers/v1/voice.py b/letta/server/rest_api/routers/v1/voice.py index 1ecbde00..0e8b08c0 100644 --- a/letta/server/rest_api/routers/v1/voice.py +++ b/letta/server/rest_api/routers/v1/voice.py @@ -6,8 +6,9 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException from fastapi.responses import StreamingResponse from openai.types.chat.completion_create_params import CompletionCreateParams +from letta.agents.low_latency_agent import LowLatencyAgent from letta.log import get_logger -from letta.low_latency_agent import LowLatencyAgent +from letta.schemas.openai.chat_completions import UserMessage from letta.server.rest_api.utils import get_letta_server, get_messages_from_completion_request from letta.settings import model_settings @@ -44,12 +45,8 @@ async def create_voice_chat_completions( if agent_id is None: raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field") - # agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - # if agent_state.llm_config.model_endpoint_type != "openai": - # raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.") - # Also parse the user's new input - input_message = get_messages_from_completion_request(completion_request)[-1] + input_message = UserMessage(**get_messages_from_completion_request(completion_request)[-1]) # Create OpenAI async client client = openai.AsyncClient( @@ -72,8 +69,11 @@ async def create_voice_chat_completions( openai_client=client, message_manager=server.message_manager, agent_manager=server.agent_manager, + 
block_manager=server.block_manager, actor=actor, + message_buffer_limit=10, + message_buffer_min=4, ) # Return the streaming generator - return StreamingResponse(agent.step(input_message=input_message), media_type="text/event-stream") + return StreamingResponse(agent.step_stream(input_message=input_message), media_type="text/event-stream") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 07dade40..d57ab21c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -535,40 +535,40 @@ class AgentManager: # TODO: This seems kind of silly, why not just update the message? message = self.message_manager.create_message(message, actor=actor) message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system) - return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) else: return agent_state @enforce_types - def _set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: + def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) @enforce_types def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message - return self._set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) + return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) @enforce_types def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = 
self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids # TODO: How do we know this? new_messages = [message_ids[0]] # 0 is system message - return self._set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) + return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) @enforce_types def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids new_messages = self.message_manager.create_many_messages(messages, actor=actor) message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:] - return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) @enforce_types def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: messages = self.message_manager.create_many_messages(messages, actor=actor) message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or [] message_ids += [m.id for m in messages] - return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) @enforce_types def reset_messages(self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False) -> PydanticAgentState: diff --git a/letta/services/summarizer/__init__.py b/letta/services/summarizer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/services/summarizer/enums.py b/letta/services/summarizer/enums.py new file mode 100644 index 00000000..33c42d65 --- /dev/null +++ b/letta/services/summarizer/enums.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class SummarizationMode(str, 
Enum): + """ + Represents possible modes of summarization for conversation trimming. + """ + + STATIC_MESSAGE_BUFFER = "static_message_buffer_mode" diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py new file mode 100644 index 00000000..7932aa56 --- /dev/null +++ b/letta/services/summarizer/summarizer.py @@ -0,0 +1,102 @@ +import json +from json import JSONDecodeError +from typing import List, Tuple + +from letta.agents.base_agent import BaseAgent +from letta.schemas.enums import MessageRole +from letta.schemas.message import Message +from letta.schemas.openai.chat_completion_request import UserMessage +from letta.services.summarizer.enums import SummarizationMode + + +class Summarizer: + """ + Handles summarization or trimming of conversation messages based on + the specified SummarizationMode. For now, we demonstrate a simple + static buffer approach but leave room for more advanced strategies. + """ + + def __init__(self, mode: SummarizationMode, summarizer_agent: BaseAgent, message_buffer_limit: int = 10, message_buffer_min: int = 3): + self.mode = mode + + # Need to do validation on this + self.message_buffer_limit = message_buffer_limit + self.message_buffer_min = message_buffer_min + self.summarizer_agent = summarizer_agent + # TODO: Move this to config + self.summary_prefix = "Out of context message summarization:\n" + + async def summarize( + self, in_context_messages: List[Message], new_letta_messages: List[Message], previous_summary: str + ) -> Tuple[List[Message], str, bool]: + """ + Summarizes or trims in_context_messages according to the chosen mode, + and returns the updated messages plus any optional "summary message". + + Args: + in_context_messages: The existing messages in the conversation's context. + new_letta_messages: The newly added Letta messages (just appended). + previous_summary: The previous summary string. 
+ + Returns: + (updated_messages, summary_message) + updated_messages: The new context after trimming/summary + summary_message: Optional summarization message that was created + (could be appended to the conversation if desired) + """ + if self.mode == SummarizationMode.STATIC_MESSAGE_BUFFER: + return await self._static_buffer_summarization(in_context_messages, new_letta_messages, previous_summary) + else: + # Fallback or future logic + return in_context_messages, "", False + + async def _static_buffer_summarization( + self, in_context_messages: List[Message], new_letta_messages: List[Message], previous_summary: str + ) -> Tuple[List[Message], str, bool]: + previous_summary = previous_summary.removeprefix(self.summary_prefix) + all_in_context_messages = in_context_messages + new_letta_messages + + # Only summarize if we exceed `message_buffer_limit` + if len(all_in_context_messages) <= self.message_buffer_limit: + return all_in_context_messages, previous_summary, False + + # Aim to trim down to `message_buffer_min` + target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1 + + # Move the trim index forward until it's at a `MessageRole.user` + while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user: + target_trim_index += 1 + + # TODO: Assuming system message is always at index 0 + updated_in_context_messages = [all_in_context_messages[0]] + all_in_context_messages[target_trim_index:] + out_of_context_messages = all_in_context_messages[:target_trim_index] + + formatted_messages = [] + for m in out_of_context_messages: + if m.content: + try: + message = json.loads(m.content[0].text).get("message") + except JSONDecodeError: + continue + if message: + formatted_messages.append(f"{m.role.value}: {message}") + + # If we didn't trim any messages, return as-is + if not formatted_messages: + return all_in_context_messages, previous_summary, False + + # Generate summarization request + 
summary_request_text = ( + "These are messages that are soon to be removed from the context window:\n" + f"{formatted_messages}\n\n" + "This is the current memory:\n" + f"{previous_summary}\n\n" + "Your task is to integrate any relevant updates from the messages into the memory. " + "It should be in note-taking format in natural English. You are to return the new, updated memory only." + ) + + messages = await self.summarizer_agent.step(UserMessage(content=summary_request_text)) + current_summary = "\n".join([m.content[0].text for m in messages]) + current_summary = f"{self.summary_prefix}{current_summary}" + + return updated_in_context_messages, current_summary, True diff --git a/tests/integration_test_chat_completions.py b/tests/integration_test_chat_completions.py index 465b322d..97320849 100644 --- a/tests/integration_test_chat_completions.py +++ b/tests/integration_test_chat_completions.py @@ -153,7 +153,7 @@ def _assert_valid_chunk(chunk, idx, chunks): @pytest.mark.asyncio -@pytest.mark.parametrize("message", ["What's the weather in SF?"]) +@pytest.mark.parametrize("message", ["How are you?"]) @pytest.mark.parametrize("endpoint", ["v1/voice"]) async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint): """Tests chat completion streaming using the Async OpenAI client.""" diff --git a/tests/manual_test_many_messages.py b/tests/manual_test_many_messages.py index c47f32dc..6aaa33bb 100644 --- a/tests/manual_test_many_messages.py +++ b/tests/manual_test_many_messages.py @@ -175,7 +175,7 @@ def test_many_messages_performance(client, num_messages): message_manager.create_many_messages(all_messages, actor=actor) log_event("Inserted messages into the database") - agent_manager._set_in_context_messages( + agent_manager.set_in_context_messages( agent_id=agent_state.id, message_ids=agent_state.message_ids + [m.id for m in all_messages], actor=client.user ) log_event("Updated agent context with messages") diff --git a/tests/test_static_buffer_summarize.py 
b/tests/test_static_buffer_summarize.py new file mode 100644 index 00000000..0fa18582 --- /dev/null +++ b/tests/test_static_buffer_summarize.py @@ -0,0 +1,157 @@ +import json +from datetime import datetime +from unittest.mock import AsyncMock + +import pytest + +from letta.agents.base_agent import BaseAgent +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import TextContent +from letta.schemas.message import Message +from letta.services.summarizer.enums import SummarizationMode +from letta.services.summarizer.summarizer import Summarizer + +# Constants for test parameters +MESSAGE_BUFFER_LIMIT = 10 +MESSAGE_BUFFER_MIN = 3 +PREVIOUS_SUMMARY = "Previous summary" +SUMMARY_TEXT = "Summarized memory" + + +@pytest.fixture +def mock_summarizer_agent(): + agent = AsyncMock(spec=BaseAgent) + agent.step.return_value = [Message(role=MessageRole.assistant, content=[TextContent(type="text", text=SUMMARY_TEXT)])] + return agent + + +@pytest.fixture +def messages(): + return [ + Message( + role=MessageRole.user if i % 2 == 0 else MessageRole.assistant, + content=[TextContent(type="text", text=json.dumps({"message": f"Test message {i}"}))], + created_at=datetime.utcnow(), + ) + for i in range(15) + ] + + +@pytest.mark.asyncio +async def test_static_buffer_summarization_no_trim_needed(mock_summarizer_agent, messages): + summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=20) + updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:5], [], PREVIOUS_SUMMARY) + + assert len(updated_messages) == 5 + assert summary == PREVIOUS_SUMMARY + assert not updated + + +@pytest.mark.asyncio +async def test_static_buffer_summarization_trim_needed(mock_summarizer_agent, messages): + summarizer = Summarizer( + SummarizationMode.STATIC_MESSAGE_BUFFER, + mock_summarizer_agent, + message_buffer_limit=MESSAGE_BUFFER_LIMIT, + message_buffer_min=MESSAGE_BUFFER_MIN, + ) + 
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + + assert len(updated_messages) == MESSAGE_BUFFER_MIN # Should be trimmed down to min buffer size + assert updated + assert SUMMARY_TEXT in summary + mock_summarizer_agent.step.assert_called() + + +@pytest.mark.asyncio +async def test_static_buffer_summarization_trim_user_message(mock_summarizer_agent, messages): + summarizer = Summarizer( + SummarizationMode.STATIC_MESSAGE_BUFFER, + mock_summarizer_agent, + message_buffer_limit=MESSAGE_BUFFER_LIMIT, + message_buffer_min=MESSAGE_BUFFER_MIN, + ) + + # Modify messages to ensure a user message is available to trim at the correct index + messages[5].role = MessageRole.user # Ensure a user message exists in trimming range + + updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + + assert len(updated_messages) == MESSAGE_BUFFER_MIN + assert updated + assert SUMMARY_TEXT in summary + mock_summarizer_agent.step.assert_called() + + +@pytest.mark.asyncio +async def test_static_buffer_summarization_no_trim_no_summarization(mock_summarizer_agent, messages): + summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=15) + updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:8], [], PREVIOUS_SUMMARY) + + assert len(updated_messages) == 8 + assert summary == PREVIOUS_SUMMARY + assert not updated + mock_summarizer_agent.step.assert_not_called() + + +@pytest.mark.asyncio +async def test_static_buffer_summarization_json_parsing_failure(mock_summarizer_agent, messages): + summarizer = Summarizer( + SummarizationMode.STATIC_MESSAGE_BUFFER, + mock_summarizer_agent, + message_buffer_limit=MESSAGE_BUFFER_LIMIT, + message_buffer_min=MESSAGE_BUFFER_MIN, + ) + + # Inject malformed JSON + messages[2].content = [TextContent(type="text", text="malformed json")] + + 
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + + assert len(updated_messages) == MESSAGE_BUFFER_MIN + assert updated + assert SUMMARY_TEXT in summary + mock_summarizer_agent.step.assert_called() + + +@pytest.mark.asyncio +async def test_static_buffer_summarization_all_user_messages_trimmed(mock_summarizer_agent, messages): + summarizer = Summarizer( + SummarizationMode.STATIC_MESSAGE_BUFFER, + mock_summarizer_agent, + message_buffer_limit=MESSAGE_BUFFER_LIMIT, + message_buffer_min=MESSAGE_BUFFER_MIN, + ) + + # Ensure all messages being trimmed are user messages + for i in range(12): + messages[i].role = MessageRole.user + + updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + + assert len(updated_messages) == MESSAGE_BUFFER_MIN + assert updated + assert SUMMARY_TEXT in summary + mock_summarizer_agent.step.assert_called() + + +@pytest.mark.asyncio +async def test_static_buffer_summarization_no_assistant_messages_trimmed(mock_summarizer_agent, messages): + summarizer = Summarizer( + SummarizationMode.STATIC_MESSAGE_BUFFER, + mock_summarizer_agent, + message_buffer_limit=MESSAGE_BUFFER_LIMIT, + message_buffer_min=MESSAGE_BUFFER_MIN, + ) + + # Ensure all messages being trimmed are assistant messages + for i in range(12): + messages[i].role = MessageRole.assistant + + updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + + # Yeah, so this actually has to end on 1, because we basically can find no user, so we trim everything + assert len(updated_messages) == 1 + assert updated + assert SUMMARY_TEXT in summary + mock_summarizer_agent.step.assert_called()