feat: Finish async memory rewriting agent for voice (#1161)
This commit is contained in:
0
letta/agents/__init__.py
Normal file
0
letta/agents/__init__.py
Normal file
51
letta/agents/base_agent.py
Normal file
51
letta/agents/base_agent.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, List
|
||||
|
||||
import openai
|
||||
|
||||
from letta.schemas.letta_message import UserMessage
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
|
||||
class BaseAgent(ABC):
    """
    Abstract base class for AI agents.

    Concrete agents receive user messages, talk to an async OpenAI client,
    and persist conversation state through the message/agent managers.
    """

    def __init__(
        self,
        agent_id: str,
        openai_client: openai.AsyncClient,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        actor: User,
    ):
        # Identity of the agent this instance operates on.
        self.agent_id = agent_id
        # Async OpenAI client used for chat completions.
        self.openai_client = openai_client
        # Persistence layers for messages and agent state.
        self.message_manager = message_manager
        self.agent_manager = agent_manager
        # User on whose behalf all DB operations are performed.
        self.actor = actor

    @abstractmethod
    async def step(self, input_message: UserMessage) -> List[Message]:
        """Run one full (non-streaming) execution loop; return the new messages."""
        raise NotImplementedError

    @abstractmethod
    async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
        """Run one execution loop, yielding results incrementally as SSE events."""
        raise NotImplementedError

    def pre_process_input_message(self, input_message: UserMessage) -> Any:
        """Convert the incoming message into the dict form expected downstream."""
        return input_message.model_dump()
|
||||
72
letta/agents/ephemeral_agent.py
Normal file
72
letta/agents/ephemeral_agent.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
|
||||
import openai
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent, UserMessage
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
|
||||
class EphemeralAgent(BaseAgent):
    """
    A stateless agent: a thin wrapper around the OpenAI chat-completions API.

    It keeps no conversation history of its own; each `step` call sends only
    the provided input message (plus the agent's LLM config) to OpenAI, and
    the resulting messages are not persisted.

    # TODO: Extend to more clients
    """

    def __init__(
        self,
        agent_id: str,
        openai_client: openai.AsyncClient,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        actor: User,
    ):
        super().__init__(
            agent_id=agent_id,
            openai_client=openai_client,
            message_manager=message_manager,
            agent_manager=agent_manager,
            actor=actor,
        )

    async def step(self, input_message: UserMessage) -> List[Message]:
        """
        Send the user's input to OpenAI and return the assistant's reply.

        Returns:
            A list containing a single ephemeral (non-persisted) assistant Message.
        """
        agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)

        input_message = self.pre_process_input_message(input_message=input_message)
        request = self._build_openai_request([input_message], agent_state)

        chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))

        # The API may return None content (e.g. if the model emits only a tool
        # call or the response is filtered); fall back to "" instead of crashing
        # on `None.strip()`.
        reply_text = (chat_completion.choices[0].message.content or "").strip()
        return [
            Message(
                role=MessageRole.assistant,
                content=[TextContent(text=reply_text)],
            )
        ]

    def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest:
        """Assemble a ChatCompletionRequest from the agent's LLM config."""
        openai_request = ChatCompletionRequest(
            model=agent_state.llm_config.model,
            messages=openai_messages,
            user=self.actor.id,
            max_completion_tokens=agent_state.llm_config.max_tokens,
            temperature=agent_state.llm_config.temperature,
        )
        return openai_request

    async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
        """
        Streaming is not implemented for this agent; use `step` instead.
        """
        raise NotImplementedError("EphemeralAgent does not support async step.")
|
||||
@@ -1,10 +1,11 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
from typing import Any, AsyncGenerator, Dict, List, Tuple
|
||||
|
||||
import openai
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.ephemeral_agent import EphemeralAgent
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.tool_execution_helper import (
|
||||
@@ -17,6 +18,7 @@ from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIC
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.message import Message, MessageUpdate
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
AssistantMessage,
|
||||
@@ -28,7 +30,6 @@ from letta.schemas.openai.chat_completion_request import (
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.utils import (
|
||||
convert_letta_messages_to_openai,
|
||||
create_assistant_messages_from_openai_response,
|
||||
@@ -36,14 +37,17 @@ from letta.server.rest_api.utils import (
|
||||
create_user_message,
|
||||
)
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer
|
||||
from letta.utils import united_diff
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LowLatencyAgent:
|
||||
class LowLatencyAgent(BaseAgent):
|
||||
"""
|
||||
A function-calling loop for streaming OpenAI responses with tool execution.
|
||||
This agent:
|
||||
@@ -58,32 +62,53 @@ class LowLatencyAgent:
|
||||
openai_client: openai.AsyncClient,
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
actor: User,
|
||||
summarization_mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER,
|
||||
message_buffer_limit: int = 10,
|
||||
message_buffer_min: int = 4,
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.openai_client = openai_client
|
||||
super().__init__(
|
||||
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
|
||||
)
|
||||
|
||||
# DB access related fields
|
||||
self.message_manager = message_manager
|
||||
self.agent_manager = agent_manager
|
||||
self.actor = actor
|
||||
# TODO: Make this more general, factorable
|
||||
# Summarizer settings
|
||||
self.block_manager = block_manager
|
||||
# TODO: This is not guaranteed to exist!
|
||||
self.summary_block_label = "human"
|
||||
self.summarizer = Summarizer(
|
||||
mode=summarization_mode,
|
||||
summarizer_agent=EphemeralAgent(
|
||||
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
|
||||
),
|
||||
message_buffer_limit=message_buffer_limit,
|
||||
message_buffer_min=message_buffer_min,
|
||||
)
|
||||
self.message_buffer_limit = message_buffer_limit
|
||||
self.message_buffer_min = message_buffer_min
|
||||
|
||||
# Internal conversation state
|
||||
self.optimistic_json_parser = OptimisticJSONParser(strict=True)
|
||||
self.current_parsed_json_result: Dict[str, Any] = {}
|
||||
async def step(self, input_message: UserMessage) -> List[Message]:
|
||||
raise NotImplementedError("LowLatencyAgent does not have a synchronous step implemented currently.")
|
||||
|
||||
async def step(self, input_message: Dict[str, str]) -> AsyncGenerator[str, None]:
|
||||
async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Async generator that yields partial tokens as SSE events, handles tool calls,
|
||||
and streams error messages if OpenAI API failures occur.
|
||||
"""
|
||||
input_message = self.pre_process_input_message(input_message=input_message)
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
|
||||
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
|
||||
letta_message_db_queue = [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)]
|
||||
in_memory_message_history = [input_message]
|
||||
|
||||
while True:
|
||||
# Build context and request
|
||||
openai_messages = self._build_context_window(in_memory_message_history, agent_state)
|
||||
# Constantly pull down and integrate memory blocks
|
||||
in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state)
|
||||
|
||||
# Convert Letta messages to OpenAI messages
|
||||
openai_messages = convert_letta_messages_to_openai(in_context_messages)
|
||||
openai_messages.extend(in_memory_message_history)
|
||||
request = self._build_openai_request(openai_messages, agent_state)
|
||||
|
||||
# Execute the request
|
||||
@@ -94,24 +119,19 @@ class LowLatencyAgent:
|
||||
yield sse
|
||||
|
||||
# Process the AI response (buffered messages, tool execution, etc.)
|
||||
continue_execution = await self.handle_ai_response(
|
||||
continue_execution = await self._handle_ai_response(
|
||||
streaming_interface, agent_state, in_memory_message_history, letta_message_db_queue
|
||||
)
|
||||
|
||||
if not continue_execution:
|
||||
break
|
||||
|
||||
# Persist messages to the database asynchronously
|
||||
await run_in_threadpool(
|
||||
self.agent_manager.append_to_in_context_messages,
|
||||
letta_message_db_queue,
|
||||
agent_id=agent_state.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
# Rebuild context window
|
||||
await self._rebuild_context_window(in_context_messages, letta_message_db_queue, agent_state)
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def handle_ai_response(
|
||||
async def _handle_ai_response(
|
||||
self,
|
||||
streaming_interface: OpenAIChatCompletionsStreamingInterface,
|
||||
agent_state: AgentState,
|
||||
@@ -194,15 +214,24 @@ class LowLatencyAgent:
|
||||
# Exit the loop if finish_reason_stop or no tool call occurred
|
||||
return not streaming_interface.finish_reason_stop
|
||||
|
||||
def _build_context_window(self, in_memory_message_history: List[Dict[str, Any]], agent_state: AgentState) -> List[Dict]:
|
||||
# Build in_context_messages
|
||||
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
|
||||
in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state)
|
||||
async def _rebuild_context_window(
|
||||
self, in_context_messages: List[Message], letta_message_db_queue: List[Message], agent_state: AgentState
|
||||
) -> None:
|
||||
new_letta_messages = self.message_manager.create_many_messages(letta_message_db_queue, actor=self.actor)
|
||||
|
||||
# Convert Letta messages to OpenAI messages
|
||||
openai_messages = convert_letta_messages_to_openai(in_context_messages)
|
||||
openai_messages.extend(in_memory_message_history)
|
||||
return openai_messages
|
||||
# TODO: Make this more general and configurable, less brittle
|
||||
target_block = next(b for b in agent_state.memory.blocks if b.label == self.summary_block_label)
|
||||
previous_summary = self.block_manager.get_block_by_id(block_id=target_block.id, actor=self.actor).value
|
||||
new_in_context_messages, summary_str, updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, previous_summary=previous_summary
|
||||
)
|
||||
|
||||
if updated:
|
||||
self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=summary_str), actor=self.actor)
|
||||
|
||||
self.agent_manager.set_in_context_messages(
|
||||
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
|
||||
)
|
||||
|
||||
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
|
||||
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
|
||||
@@ -264,7 +293,7 @@ class LowLatencyAgent:
|
||||
for t in tools
|
||||
]
|
||||
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> (str, bool):
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
|
||||
"""
|
||||
Executes a tool and returns (result, success_flag).
|
||||
"""
|
||||
@@ -54,9 +54,9 @@ DEVELOPMENT_LOGGING = {
|
||||
"propagate": True, # Let logs bubble up to root
|
||||
},
|
||||
"uvicorn": {
|
||||
"level": "CRITICAL",
|
||||
"level": "DEBUG",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
"propagate": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -6,8 +6,9 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
|
||||
from letta.agents.low_latency_agent import LowLatencyAgent
|
||||
from letta.log import get_logger
|
||||
from letta.low_latency_agent import LowLatencyAgent
|
||||
from letta.schemas.openai.chat_completions import UserMessage
|
||||
from letta.server.rest_api.utils import get_letta_server, get_messages_from_completion_request
|
||||
from letta.settings import model_settings
|
||||
|
||||
@@ -44,12 +45,8 @@ async def create_voice_chat_completions(
|
||||
if agent_id is None:
|
||||
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
|
||||
|
||||
# agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
# if agent_state.llm_config.model_endpoint_type != "openai":
|
||||
# raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.")
|
||||
|
||||
# Also parse the user's new input
|
||||
input_message = get_messages_from_completion_request(completion_request)[-1]
|
||||
input_message = UserMessage(**get_messages_from_completion_request(completion_request)[-1])
|
||||
|
||||
# Create OpenAI async client
|
||||
client = openai.AsyncClient(
|
||||
@@ -72,8 +69,11 @@ async def create_voice_chat_completions(
|
||||
openai_client=client,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
actor=actor,
|
||||
message_buffer_limit=10,
|
||||
message_buffer_min=4,
|
||||
)
|
||||
|
||||
# Return the streaming generator
|
||||
return StreamingResponse(agent.step(input_message=input_message), media_type="text/event-stream")
|
||||
return StreamingResponse(agent.step_stream(input_message=input_message), media_type="text/event-stream")
|
||||
|
||||
@@ -535,40 +535,40 @@ class AgentManager:
|
||||
# TODO: This seems kind of silly, why not just update the message?
|
||||
message = self.message_manager.create_message(message, actor=actor)
|
||||
message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system)
|
||||
return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
def _set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
|
||||
return self._set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
# TODO: How do we know this?
|
||||
new_messages = [message_ids[0]] # 0 is system message
|
||||
return self._set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
new_messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||||
message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:]
|
||||
return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
|
||||
message_ids += [m.id for m in messages]
|
||||
return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def reset_messages(self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False) -> PydanticAgentState:
|
||||
|
||||
0
letta/services/summarizer/__init__.py
Normal file
0
letta/services/summarizer/__init__.py
Normal file
9
letta/services/summarizer/enums.py
Normal file
9
letta/services/summarizer/enums.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SummarizationMode(str, Enum):
    """Strategies available for trimming/summarizing conversation history."""

    # Keep a fixed-size buffer of recent messages; summarize everything older.
    STATIC_MESSAGE_BUFFER = "static_message_buffer_mode"
|
||||
102
letta/services/summarizer/summarizer.py
Normal file
102
letta/services/summarizer/summarizer.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Tuple
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import UserMessage
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
|
||||
|
||||
class Summarizer:
    """
    Handles summarization or trimming of conversation messages based on
    the specified SummarizationMode. For now, only a simple static message
    buffer strategy is implemented, but the interface leaves room for more
    advanced strategies.
    """

    def __init__(self, mode: SummarizationMode, summarizer_agent: BaseAgent, message_buffer_limit: int = 10, message_buffer_min: int = 3):
        self.mode = mode

        # TODO: Validate that message_buffer_min <= message_buffer_limit.
        self.message_buffer_limit = message_buffer_limit
        self.message_buffer_min = message_buffer_min
        self.summarizer_agent = summarizer_agent
        # TODO: Move this to config
        self.summary_prefix = "Out of context message summarization:\n"

    async def summarize(
        self, in_context_messages: List[Message], new_letta_messages: List[Message], previous_summary: str
    ) -> Tuple[List[Message], str, bool]:
        """
        Summarizes or trims in_context_messages according to the chosen mode,
        and returns the updated messages plus any optional "summary message".

        Args:
            in_context_messages: The existing messages in the conversation's context.
            new_letta_messages: The newly added Letta messages (just appended).
            previous_summary: The previous summary string (possibly carrying
                the summary prefix).

        Returns:
            (updated_messages, summary, updated)
            updated_messages: The new context after trimming/summary
            summary: The (possibly refreshed) summary string
            updated: Whether the summary was regenerated
        """
        if self.mode == SummarizationMode.STATIC_MESSAGE_BUFFER:
            return await self._static_buffer_summarization(in_context_messages, new_letta_messages, previous_summary)
        else:
            # Fallback or future logic: pass everything through unchanged.
            return in_context_messages, "", False

    async def _static_buffer_summarization(
        self, in_context_messages: List[Message], new_letta_messages: List[Message], previous_summary: str
    ) -> Tuple[List[Message], str, bool]:
        """Trim context to `message_buffer_min` messages once it exceeds
        `message_buffer_limit`, folding the trimmed messages into the summary."""
        # Strip the stored prefix so prompts don't nest it repeatedly.
        # (BUG FIX: this previously did `previous_summary[: len(self.summary_prefix)]`,
        # which *truncated* the summary to the first len(prefix) characters instead
        # of removing the prefix.)
        if previous_summary.startswith(self.summary_prefix):
            previous_summary = previous_summary[len(self.summary_prefix):]

        all_in_context_messages = in_context_messages + new_letta_messages

        # Only summarize if we exceed `message_buffer_limit`
        if len(all_in_context_messages) <= self.message_buffer_limit:
            return all_in_context_messages, previous_summary, False

        # Aim to trim down to `message_buffer_min`
        target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1

        # Move the trim index forward until it's at a `MessageRole.user`, so the
        # remaining context always starts on a user turn.
        while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user:
            target_trim_index += 1

        # TODO: Assuming system message is always at index 0
        updated_in_context_messages = [all_in_context_messages[0]] + all_in_context_messages[target_trim_index:]
        out_of_context_messages = all_in_context_messages[:target_trim_index]

        # Render the trimmed messages as "role: message" lines; skip entries
        # whose payload isn't valid JSON or has no "message" key.
        formatted_messages = []
        for m in out_of_context_messages:
            if m.content:
                try:
                    message = json.loads(m.content[0].text).get("message")
                except JSONDecodeError:
                    continue
                if message:
                    formatted_messages.append(f"{m.role.value}: {message}")

        # If we didn't trim any messages, return as-is
        if not formatted_messages:
            return all_in_context_messages, previous_summary, False

        # Generate summarization request
        summary_request_text = (
            "These are messages that are soon to be removed from the context window:\n"
            f"{formatted_messages}\n\n"
            "This is the current memory:\n"
            f"{previous_summary}\n\n"
            "Your task is to integrate any relevant updates from the messages into the memory."
            "It should be in note-taking format in natural English. You are to return the new, updated memory only."
        )

        messages = await self.summarizer_agent.step(UserMessage(content=summary_request_text))
        # NOTE(review): assumes returned Message objects expose a `.text`
        # attribute; elsewhere content is `[TextContent(...)]` — confirm the
        # Message schema provides this accessor.
        current_summary = "\n".join([m.text for m in messages])
        current_summary = f"{self.summary_prefix}{current_summary}"

        return updated_in_context_messages, current_summary, True
|
||||
@@ -153,7 +153,7 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["What's the weather in SF?"])
|
||||
@pytest.mark.parametrize("message", ["How are you?"])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice"])
|
||||
async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
|
||||
@@ -175,7 +175,7 @@ def test_many_messages_performance(client, num_messages):
|
||||
message_manager.create_many_messages(all_messages, actor=actor)
|
||||
log_event("Inserted messages into the database")
|
||||
|
||||
agent_manager._set_in_context_messages(
|
||||
agent_manager.set_in_context_messages(
|
||||
agent_id=agent_state.id, message_ids=agent_state.message_ids + [m.id for m in all_messages], actor=client.user
|
||||
)
|
||||
log_event("Updated agent context with messages")
|
||||
|
||||
157
tests/test_static_buffer_summarize.py
Normal file
157
tests/test_static_buffer_summarize.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer
|
||||
|
||||
# Constants for test parameters
|
||||
MESSAGE_BUFFER_LIMIT = 10
|
||||
MESSAGE_BUFFER_MIN = 3
|
||||
PREVIOUS_SUMMARY = "Previous summary"
|
||||
SUMMARY_TEXT = "Summarized memory"
|
||||
|
||||
|
||||
@pytest.fixture
def mock_summarizer_agent():
    """Async-mocked BaseAgent whose `step` always returns one fixed summary message."""
    agent = AsyncMock(spec=BaseAgent)
    summary_message = Message(
        role=MessageRole.assistant,
        content=[TextContent(type="text", text=SUMMARY_TEXT)],
    )
    agent.step.return_value = [summary_message]
    return agent
|
||||
|
||||
|
||||
@pytest.fixture
def messages():
    """Fifteen alternating user/assistant messages with JSON-wrapped payloads."""
    generated = []
    for idx in range(15):
        role = MessageRole.user if idx % 2 == 0 else MessageRole.assistant
        payload = json.dumps({"message": f"Test message {idx}"})
        generated.append(
            Message(
                role=role,
                content=[TextContent(type="text", text=payload)],
                created_at=datetime.utcnow(),
            )
        )
    return generated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_static_buffer_summarization_no_trim_needed(mock_summarizer_agent, messages):
    """Below the buffer limit, nothing is trimmed and the summary is untouched."""
    summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=20)

    result, summary, updated = await summarizer._static_buffer_summarization(messages[:5], [], PREVIOUS_SUMMARY)

    assert not updated
    assert summary == PREVIOUS_SUMMARY
    assert len(result) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_static_buffer_summarization_trim_needed(mock_summarizer_agent, messages):
    """Exceeding the limit trims down to the minimum buffer and regenerates the summary."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )

    result, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    mock_summarizer_agent.step.assert_called()
    assert updated
    assert SUMMARY_TEXT in summary
    # Context should have been reduced to the minimum buffer size.
    assert len(result) == MESSAGE_BUFFER_MIN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_static_buffer_summarization_trim_user_message(mock_summarizer_agent, messages):
    """The trim index snaps forward to a user message when one is in range."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )

    # Guarantee a user message sits within the trimming range.
    messages[5].role = MessageRole.user

    result, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    mock_summarizer_agent.step.assert_called()
    assert updated
    assert SUMMARY_TEXT in summary
    assert len(result) == MESSAGE_BUFFER_MIN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_static_buffer_summarization_no_trim_no_summarization(mock_summarizer_agent, messages):
    """With a generous limit, the summarizer agent is never invoked."""
    summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=15)

    result, summary, updated = await summarizer._static_buffer_summarization(messages[:8], [], PREVIOUS_SUMMARY)

    mock_summarizer_agent.step.assert_not_called()
    assert not updated
    assert summary == PREVIOUS_SUMMARY
    assert len(result) == 8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_static_buffer_summarization_json_parsing_failure(mock_summarizer_agent, messages):
    """A message with unparseable JSON is skipped without aborting summarization."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )

    # Inject malformed JSON
    messages[2].content = [TextContent(type="text", text="malformed json")]

    result, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    mock_summarizer_agent.step.assert_called()
    assert updated
    assert SUMMARY_TEXT in summary
    assert len(result) == MESSAGE_BUFFER_MIN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_static_buffer_summarization_all_user_messages_trimmed(mock_summarizer_agent, messages):
    """Trimming behaves normally when every message in range is a user message."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )

    # Force every candidate message to the user role.
    for idx in range(12):
        messages[idx].role = MessageRole.user

    result, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    mock_summarizer_agent.step.assert_called()
    assert updated
    assert SUMMARY_TEXT in summary
    assert len(result) == MESSAGE_BUFFER_MIN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_static_buffer_summarization_no_assistant_messages_trimmed(mock_summarizer_agent, messages):
    """With no user message to anchor on, everything but the system message is trimmed."""
    summarizer = Summarizer(
        SummarizationMode.STATIC_MESSAGE_BUFFER,
        mock_summarizer_agent,
        message_buffer_limit=MESSAGE_BUFFER_LIMIT,
        message_buffer_min=MESSAGE_BUFFER_MIN,
    )

    # Force every candidate message to the assistant role.
    for idx in range(12):
        messages[idx].role = MessageRole.assistant

    result, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)

    mock_summarizer_agent.step.assert_called()
    assert updated
    assert SUMMARY_TEXT in summary
    # No user message can be found, so the trim index runs to the end and only
    # the system message (index 0) survives.
    assert len(result) == 1
|
||||
Reference in New Issue
Block a user