feat: expand summarizer providers

This commit is contained in:
Andy Li
2025-07-01 15:07:38 -07:00
committed by GitHub
parent efca9d8ea0
commit 32f2cf17b5
16 changed files with 102 additions and 98 deletions

View File

@@ -67,7 +67,8 @@ class BaseAgent(ABC):
"""
raise NotImplementedError
def pre_process_input_message(self, input_messages: List[MessageCreate]) -> Any:
@staticmethod
def pre_process_input_message(input_messages: List[MessageCreate]) -> Any:
"""
Pre-process function to run on the input_message.
"""

View File

@@ -1,27 +1,28 @@
from pathlib import Path
from typing import AsyncGenerator, Dict, List
from openai import AsyncOpenAI
from typing import AsyncGenerator, List
from letta.agents.base_agent import BaseAgent
from letta.constants import DEFAULT_MAX_STEPS
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.prompts.gpt_system import get_system_text
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
logger = get_logger(__name__)
class EphemeralSummaryAgent(BaseAgent):
"""
A stateless summarization agent (thin wrapper around OpenAI)
# TODO: Extend to more clients
A stateless summarization agent that utilizes the caller's LLM client to summarize the conversation.
TODO (cliandy): consider allowing the summarizer to use an llm_config different from the main agent's
"""
def __init__(
@@ -35,7 +36,7 @@ class EphemeralSummaryAgent(BaseAgent):
):
super().__init__(
agent_id=agent_id,
openai_client=AsyncOpenAI(),
openai_client=None,
message_manager=message_manager,
agent_manager=agent_manager,
actor=actor,
@@ -65,17 +66,33 @@ class EphemeralSummaryAgent(BaseAgent):
input_message = input_messages[0]
input_message.content[0].text += f"\n\n--- Previous Summary ---\n{block.value}\n"
openai_messages = self.pre_process_input_message(input_messages=input_messages)
request = self._build_openai_request(openai_messages)
# Get the LLMClient based on the calling agent's LLM Config
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
llm_client = LLMClient.create(
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor=self.actor,
)
# TODO: Extend to generic client
chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
summary = chat_completion.choices[0].message.content.strip()
system_message_create = MessageCreate(
role=MessageRole.system,
content=[TextContent(text=get_system_text("summary_system_prompt"))],
)
messages = convert_message_creates_to_messages(
message_creates=[system_message_create] + input_messages,
agent_id=self.agent_id,
timezone=agent_state.timezone,
)
request_data = llm_client.build_request_data(messages, agent_state.llm_config, tools=[])
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
response = llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)
summary = response.choices[0].message.content.strip()
await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=summary), actor=self.actor)
print(block)
print(summary)
logger.debug("block:", block)
logger.debug("summary:", summary)
return [
Message(
@@ -84,22 +101,5 @@ class EphemeralSummaryAgent(BaseAgent):
)
]
def _build_openai_request(self, openai_messages: List[Dict]) -> ChatCompletionRequest:
current_dir = Path(__file__).parent
file_path = current_dir / "prompts" / "summary_system_prompt.txt"
with open(file_path, "r") as file:
system = file.read()
system_message = [{"role": "system", "content": system}]
openai_request = ChatCompletionRequest(
model="gpt-4o",
messages=system_message + openai_messages,
user=self.actor.id,
max_completion_tokens=4096,
temperature=0.7,
)
return openai_request
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]:
raise NotImplementedError("EphemeralAgent does not support async step.")

View File

@@ -58,11 +58,15 @@ from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import model_settings
from letta.settings import model_settings, summarizer_settings
from letta.system import package_function_response
from letta.types import JsonDict
from letta.utils import log_telemetry, validate_function_response
logger = get_logger(__name__)
DEFAULT_SUMMARY_BLOCK_LABEL = "conversation_summary"
class LettaAgent(BaseAgent):
@@ -77,11 +81,11 @@ class LettaAgent(BaseAgent):
actor: User,
step_manager: StepManager = NoopStepManager(),
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
summary_block_label: str = "conversation_summary",
message_buffer_limit: int = 60, # TODO: Make this configurable
message_buffer_min: int = 15, # TODO: Make this configurable
enable_summarization: bool = True, # TODO: Make this configurable
max_summarization_retries: int = 3, # TODO: Make this configurable
summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL,
message_buffer_limit: int = summarizer_settings.message_buffer_limit,
message_buffer_min: int = summarizer_settings.message_buffer_min,
enable_summarization: bool = summarizer_settings.enable_summarization,
max_summarization_retries: int = summarizer_settings.max_summarization_retries,
):
super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)
@@ -117,7 +121,7 @@ class LettaAgent(BaseAgent):
)
self.summarizer = Summarizer(
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
mode=SummarizationMode(summarizer_settings.mode),
summarizer_agent=self.summarization_agent,
# TODO: Make this configurable
message_buffer_limit=message_buffer_limit,

View File

@@ -302,6 +302,7 @@ class OpenAIProvider(Provider):
if self.base_url == "https://api.openai.com/v1":
allowed_types = ["gpt-4", "o1", "o3", "o4"]
# NOTE: o1-mini and o1-preview do not support tool calling
# NOTE: o1-mini does not support system messages
# NOTE: o1-pro is only available in Responses API
disallowed_types = ["transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"]
skip = True

View File

@@ -90,7 +90,7 @@ async def add_feedback(
@router.patch("/{step_id}/transaction/{transaction_id}", response_model=Step, operation_id="update_step_transaction_id")
def update_step_transaction_id(
async def update_step_transaction_id(
step_id: str,
transaction_id: str,
actor_id: Optional[str] = Header(None, alias="user_id"),
@@ -102,6 +102,6 @@ def update_step_transaction_id(
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
return server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id)
return await server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id)
except NoResultFound:
raise HTTPException(status_code=404, detail="Step not found")

View File

@@ -173,7 +173,7 @@ class StepManager:
@enforce_types
@trace_method
def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
async def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
"""Update the transaction ID for a step.
Args:
@@ -187,15 +187,15 @@ class StepManager:
Raises:
NoResultFound: If the step does not exist
"""
with db_registry.session() as session:
step = session.get(StepModel, step_id)
async with db_registry.async_session() as session:
step = await session.get(StepModel, step_id)
if not step:
raise NoResultFound(f"Step with id {step_id} does not exist")
if step.organization_id != actor.organization_id:
raise Exception("Unauthorized")
step.tid = transaction_id
session.commit()
await session.commit()
return step.to_pydantic()
def _verify_job_access(
@@ -226,8 +226,8 @@ class StepManager:
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
return job
@staticmethod
async def _verify_job_access_async(
self,
session: AsyncSession,
job_id: str,
actor: PydanticUser,

View File

@@ -11,6 +11,7 @@ from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.services.summarizer.enums import SummarizationMode
from letta.templates.template_helper import render_template
logger = get_logger(__name__)
@@ -123,30 +124,13 @@ class Summarizer:
formatted_evicted_messages = [f"{i}. {msg}" for (i, msg) in enumerate(formatted_evicted_messages)]
formatted_in_context_messages = [f"{i + offset}. {msg}" for (i, msg) in enumerate(formatted_in_context_messages)]
evicted_messages_str = "\n".join(formatted_evicted_messages)
in_context_messages_str = "\n".join(formatted_in_context_messages)
# Base prompt
prompt_header = (
f"Youre a memory-recall helper for an AI that can only keep the last {retain_count} messages. "
"Scan the conversation history, focusing on messages about to drop out of that window, "
"and write crisp notes that capture any important facts or insights about the conversation history so they arent lost."
summary_request_text = render_template(
"summary_request_text.j2",
retain_count=retain_count,
evicted_messages=formatted_evicted_messages,
in_context_messages=formatted_in_context_messages,
)
# Sections
evicted_section = f"\n\n(Older) Evicted Messages:\n{evicted_messages_str}" if evicted_messages_str.strip() else ""
in_context_section = ""
if retain_count > 0 and in_context_messages_str.strip():
in_context_section = f"\n\n(Newer) In-Context Messages:\n{in_context_messages_str}"
elif retain_count == 0:
prompt_header = (
"Youre a memory-recall helper for an AI that is about to forget all prior messages. "
"Scan the conversation history and write crisp notes that capture any important facts or insights about the conversation history."
)
# Compose final prompt
summary_request_text = prompt_header + evicted_section + in_context_section
# Fire-and-forget the summarization task
self.fire_and_forget(
self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])

View File

@@ -6,6 +6,7 @@ from pydantic import AliasChoices, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from letta.local_llm.constants import DEFAULT_WRAPPER_NAME
from letta.services.summarizer.enums import SummarizationMode
class ToolSettings(BaseSettings):
@@ -38,6 +39,13 @@ class ToolSettings(BaseSettings):
class SummarizerSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="letta_summarizer_", extra="ignore")
mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER
message_buffer_limit: int = 60
message_buffer_min: int = 15
enable_summarization: bool = True
max_summarization_retries: int = 3
# TODO(cliandy): the below settings are tied to old summarization and should be deprecated or moved
# Controls if we should evict all messages
# TODO: Can refactor this into an enum if we have a bunch of different kinds of summarizers
evict_all_messages: bool = False

View File

@@ -0,0 +1,19 @@
{% if retain_count == 0 %}
You're a memory-recall helper for an AI that is about to forget all prior messages. Scan the conversation history and write crisp notes that capture any important facts or insights about the conversation history.
{% else %}
You're a memory-recall helper for an AI that can only keep the last {{ retain_count }} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren't lost.
{% endif %}
{% if evicted_messages %}
(Older) Evicted Messages:
{% for item in evicted_messages %}
{{ item }}
{% endfor %}
{% endif %}
{% if retain_count > 0 and in_context_messages %}
(Newer) In-Context Messages:
{% for item in in_context_messages %}
{{ item }}
{% endfor %}
{% endif %}

View File

@@ -1,8 +0,0 @@
{
"context_window": 128000,
"model": "o1-mini",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": null,
"temperature": 1.0
}

View File

@@ -103,7 +103,6 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
all_configs = [
"openai-gpt-4o-mini.json",
"openai-o1.json",
"openai-o1-mini.json",
"openai-o3.json",
"openai-o4-mini.json",
"azure-gpt-4o-mini.json",
@@ -116,7 +115,7 @@ all_configs = [
"gemini-2.5-flash-vertex.json",
"gemini-2.5-pro-vertex.json",
"together-qwen-2.5-72b-instruct.json",
"ollama.json",
# "ollama.json", # TODO (cliandy): enable this in ollama testing
]
@@ -1215,7 +1214,7 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM
new_llm_config = llm_config.model_dump()
new_llm_config["context_window"] = 3000
pinned_context_window_llm_config = LLMConfig(**new_llm_config)
print("::LLM::", llm_config, new_llm_config)
send_message_tool = client.tools.list(name="send_message")[0]
temp_agent_state = client.agents.create(
include_base_tools=False,

View File

@@ -54,6 +54,7 @@ def actor(server, org_id):
server.user_manager.delete_user_by_id(user.id)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.asyncio(loop_scope="module")
async def test_sleeptime_group_chat(server, actor):
# 0. Refresh base tools

View File

@@ -2,6 +2,7 @@
pythonpath = /letta
testpaths = /tests
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
filterwarnings =
ignore::pytest.PytestRemovedIn9Warning
# suppresses the warnings we see with the event_loop fixture

View File

@@ -443,6 +443,7 @@ def test_function_return_limit(disable_e2b_api_key, client: LettaSDKClient, agen
assert "function output was truncated " in res
@pytest.mark.flaky(max_runs=3)
def test_function_always_error(client: LettaSDKClient, agent: AgentState):
"""Test to see if function that errors works correctly"""

View File

@@ -14,7 +14,6 @@ from letta.services.summarizer.summarizer import Summarizer
# Constants for test parameters
MESSAGE_BUFFER_LIMIT = 10
MESSAGE_BUFFER_MIN = 3
PREVIOUS_SUMMARY = "Previous summary"
SUMMARY_TEXT = "Summarized memory"
@@ -22,6 +21,7 @@ SUMMARY_TEXT = "Summarized memory"
def mock_summarizer_agent():
agent = AsyncMock(spec=BaseAgent)
agent.step.return_value = [Message(role=MessageRole.assistant, content=[TextContent(type="text", text=SUMMARY_TEXT)])]
agent.update_message_transcript = AsyncMock()
return agent
@@ -40,10 +40,9 @@ def messages():
@pytest.mark.asyncio
async def test_static_buffer_summarization_no_trim_needed(mock_summarizer_agent, messages):
summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=20)
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:5], [], PREVIOUS_SUMMARY)
updated_messages, updated = summarizer._static_buffer_summarization(messages[:5], [])
assert len(updated_messages) == 5
assert summary == PREVIOUS_SUMMARY
assert not updated
@@ -55,11 +54,10 @@ async def test_static_buffer_summarization_trim_needed(mock_summarizer_agent, me
message_buffer_limit=MESSAGE_BUFFER_LIMIT,
message_buffer_min=MESSAGE_BUFFER_MIN,
)
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
assert len(updated_messages) == MESSAGE_BUFFER_MIN # Should be trimmed down to min buffer size
assert len(updated_messages) == MESSAGE_BUFFER_MIN
assert updated
assert SUMMARY_TEXT in summary
mock_summarizer_agent.step.assert_called()
@@ -75,21 +73,19 @@ async def test_static_buffer_summarization_trim_user_message(mock_summarizer_age
# Modify messages to ensure a user message is available to trim at the correct index
messages[5].role = MessageRole.user # Ensure a user message exists in trimming range
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
assert len(updated_messages) == MESSAGE_BUFFER_MIN
assert updated
assert SUMMARY_TEXT in summary
mock_summarizer_agent.step.assert_called()
@pytest.mark.asyncio
async def test_static_buffer_summarization_no_trim_no_summarization(mock_summarizer_agent, messages):
summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=15)
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:8], [], PREVIOUS_SUMMARY)
updated_messages, updated = summarizer._static_buffer_summarization(messages[:8], [])
assert len(updated_messages) == 8
assert summary == PREVIOUS_SUMMARY
assert not updated
mock_summarizer_agent.step.assert_not_called()
@@ -106,11 +102,10 @@ async def test_static_buffer_summarization_json_parsing_failure(mock_summarizer_
# Inject malformed JSON
messages[2].content = [TextContent(type="text", text="malformed json")]
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
assert len(updated_messages) == MESSAGE_BUFFER_MIN
assert updated
assert SUMMARY_TEXT in summary
mock_summarizer_agent.step.assert_called()
@@ -127,11 +122,10 @@ async def test_static_buffer_summarization_all_user_messages_trimmed(mock_summar
for i in range(12):
messages[i].role = MessageRole.user
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
assert len(updated_messages) == MESSAGE_BUFFER_MIN
assert len(updated_messages) == MESSAGE_BUFFER_MIN + 1
assert updated
assert SUMMARY_TEXT in summary
mock_summarizer_agent.step.assert_called()
@@ -148,10 +142,9 @@ async def test_static_buffer_summarization_no_assistant_messages_trimmed(mock_su
for i in range(12):
messages[i].role = MessageRole.assistant
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY)
updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], [])
# Yeah, so this actually has to end on 1, because we basically can find no user, so we trim everything
assert len(updated_messages) == 1
assert updated
assert SUMMARY_TEXT in summary
mock_summarizer_agent.step.assert_called()