diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 988c6ca4..a752474a 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -67,7 +67,8 @@ class BaseAgent(ABC): """ raise NotImplementedError - def pre_process_input_message(self, input_messages: List[MessageCreate]) -> Any: + @staticmethod + def pre_process_input_message(input_messages: List[MessageCreate]) -> Any: """ Pre-process function to run on the input_message. """ diff --git a/letta/agents/ephemeral_summary_agent.py b/letta/agents/ephemeral_summary_agent.py index 572e3c78..55d610c2 100644 --- a/letta/agents/ephemeral_summary_agent.py +++ b/letta/agents/ephemeral_summary_agent.py @@ -1,27 +1,28 @@ -from pathlib import Path -from typing import AsyncGenerator, Dict, List - -from openai import AsyncOpenAI +from typing import AsyncGenerator, List from letta.agents.base_agent import BaseAgent from letta.constants import DEFAULT_MAX_STEPS +from letta.helpers.message_helper import convert_message_creates_to_messages +from letta.llm_api.llm_client import LLMClient +from letta.log import get_logger from letta.orm.errors import NoResultFound +from letta.prompts.gpt_system import get_system_text from letta.schemas.block import Block, BlockUpdate from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate -from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.user import User from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.message_manager import MessageManager +logger = get_logger(__name__) + class EphemeralSummaryAgent(BaseAgent): """ - A stateless summarization agent (thin wrapper around OpenAI) - - # TODO: Extend to more clients + A stateless summarization agent that utilizes the caller's LLM client to summarize the conversation. 
+ TODO (cliandy): allow the summarizer to use an llm_config different from the main agent's. """ def __init__( @@ -35,7 +36,7 @@ class EphemeralSummaryAgent(BaseAgent): ): super().__init__( agent_id=agent_id, - openai_client=AsyncOpenAI(), + openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor, @@ -65,17 +66,33 @@ class EphemeralSummaryAgent(BaseAgent): input_message = input_messages[0] input_message.content[0].text += f"\n\n--- Previous Summary ---\n{block.value}\n" - openai_messages = self.pre_process_input_message(input_messages=input_messages) - request = self._build_openai_request(openai_messages) + # Get the LLMClient based on the calling agent's LLM config + agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor) + llm_client = LLMClient.create( + provider_type=agent_state.llm_config.model_endpoint_type, + put_inner_thoughts_first=True, + actor=self.actor, + ) - # TODO: Extend to generic client - chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True)) - summary = chat_completion.choices[0].message.content.strip() + system_message_create = MessageCreate( + role=MessageRole.system, + content=[TextContent(text=get_system_text("summary_system_prompt"))], + ) + messages = convert_message_creates_to_messages( + message_creates=[system_message_create] + input_messages, + agent_id=self.agent_id, + timezone=agent_state.timezone, + ) + + request_data = llm_client.build_request_data(messages, agent_state.llm_config, tools=[]) + response_data = await llm_client.request_async(request_data, agent_state.llm_config) + response = llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config) + summary = response.choices[0].message.content.strip() await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=summary), actor=self.actor) - print(block) - print(summary) + 
logger.debug("block: %s", block) + logger.debug("summary: %s", summary) return [ Message( @@ -84,22 +101,5 @@ class EphemeralSummaryAgent(BaseAgent): ) ] - def _build_openai_request(self, openai_messages: List[Dict]) -> ChatCompletionRequest: - current_dir = Path(__file__).parent - file_path = current_dir / "prompts" / "summary_system_prompt.txt" - with open(file_path, "r") as file: - system = file.read() - - system_message = [{"role": "system", "content": system}] - - openai_request = ChatCompletionRequest( - model="gpt-4o", - messages=system_message + openai_messages, - user=self.actor.id, - max_completion_tokens=4096, - temperature=0.7, - ) - return openai_request - async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]: raise NotImplementedError("EphemeralAgent does not support async step.") diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index cdf2c52a..721f915a 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -58,11 +58,15 @@ from letta.services.summarizer.enums import SummarizationMode from letta.services.summarizer.summarizer import Summarizer from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager -from letta.settings import model_settings +from letta.settings import model_settings, summarizer_settings from letta.system import package_function_response from letta.types import JsonDict from letta.utils import log_telemetry, validate_function_response +logger = get_logger(__name__) + +DEFAULT_SUMMARY_BLOCK_LABEL = "conversation_summary" + class LettaAgent(BaseAgent): @@ -77,11 +81,11 @@ class LettaAgent(BaseAgent): actor: User, step_manager: StepManager = NoopStepManager(), telemetry_manager: TelemetryManager = NoopTelemetryManager(), - summary_block_label: str = "conversation_summary", - message_buffer_limit: int = 60, # TODO: 
Make this configurable - message_buffer_min: int = 15, # TODO: Make this configurable - enable_summarization: bool = True, # TODO: Make this configurable - max_summarization_retries: int = 3, # TODO: Make this configurable + summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL, + message_buffer_limit: int = summarizer_settings.message_buffer_limit, + message_buffer_min: int = summarizer_settings.message_buffer_min, + enable_summarization: bool = summarizer_settings.enable_summarization, + max_summarization_retries: int = summarizer_settings.max_summarization_retries, ): super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor) @@ -117,7 +121,7 @@ class LettaAgent(BaseAgent): ) self.summarizer = Summarizer( - mode=SummarizationMode.STATIC_MESSAGE_BUFFER, + mode=SummarizationMode(summarizer_settings.mode), summarizer_agent=self.summarization_agent, # TODO: Make this configurable message_buffer_limit=message_buffer_limit, diff --git a/letta/agents/prompts/summary_system_prompt.txt b/letta/prompts/system/summary_system_prompt.txt similarity index 100% rename from letta/agents/prompts/summary_system_prompt.txt rename to letta/prompts/system/summary_system_prompt.txt diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index f8ae1d88..9e741dd8 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -302,6 +302,7 @@ class OpenAIProvider(Provider): if self.base_url == "https://api.openai.com/v1": allowed_types = ["gpt-4", "o1", "o3", "o4"] # NOTE: o1-mini and o1-preview do not support tool calling + # NOTE: o1-mini does not support system messages # NOTE: o1-pro is only available in Responses API disallowed_types = ["transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"] skip = True diff --git a/letta/server/rest_api/routers/v1/steps.py b/letta/server/rest_api/routers/v1/steps.py index 9fe11ee8..cdb401ed 100644 --- 
a/letta/server/rest_api/routers/v1/steps.py +++ b/letta/server/rest_api/routers/v1/steps.py @@ -90,7 +90,7 @@ async def add_feedback( @router.patch("/{step_id}/transaction/{transaction_id}", response_model=Step, operation_id="update_step_transaction_id") -def update_step_transaction_id( +async def update_step_transaction_id( step_id: str, transaction_id: str, actor_id: Optional[str] = Header(None, alias="user_id"), @@ -102,6 +102,6 @@ def update_step_transaction_id( actor = server.user_manager.get_user_or_default(user_id=actor_id) try: - return server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id) + return await server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id) except NoResultFound: raise HTTPException(status_code=404, detail="Step not found") diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index d4d5c7b3..138071ea 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -173,7 +173,7 @@ class StepManager: @enforce_types @trace_method - def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep: + async def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep: """Update the transaction ID for a step. 
Args: @@ -187,15 +187,15 @@ class StepManager: Raises: NoResultFound: If the step does not exist """ - with db_registry.session() as session: - step = session.get(StepModel, step_id) + async with db_registry.async_session() as session: + step = await session.get(StepModel, step_id) if not step: raise NoResultFound(f"Step with id {step_id} does not exist") if step.organization_id != actor.organization_id: raise Exception("Unauthorized") step.tid = transaction_id - session.commit() + await session.commit() return step.to_pydantic() def _verify_job_access( @@ -226,8 +226,8 @@ class StepManager: raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") return job + @staticmethod async def _verify_job_access_async( - self, session: AsyncSession, job_id: str, actor: PydanticUser, diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index fa6d18d5..7795117e 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -11,6 +11,7 @@ from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate from letta.services.summarizer.enums import SummarizationMode +from letta.templates.template_helper import render_template logger = get_logger(__name__) @@ -123,30 +124,13 @@ class Summarizer: formatted_evicted_messages = [f"{i}. {msg}" for (i, msg) in enumerate(formatted_evicted_messages)] formatted_in_context_messages = [f"{i + offset}. {msg}" for (i, msg) in enumerate(formatted_in_context_messages)] - evicted_messages_str = "\n".join(formatted_evicted_messages) - in_context_messages_str = "\n".join(formatted_in_context_messages) - # Base prompt - prompt_header = ( - f"You’re a memory-recall helper for an AI that can only keep the last {retain_count} messages. 
" - "Scan the conversation history, focusing on messages about to drop out of that window, " - "and write crisp notes that capture any important facts or insights about the conversation history so they aren’t lost." + summary_request_text = render_template( + "summary_request_text.j2", + retain_count=retain_count, + evicted_messages=formatted_evicted_messages, + in_context_messages=formatted_in_context_messages, ) - # Sections - evicted_section = f"\n\n(Older) Evicted Messages:\n{evicted_messages_str}" if evicted_messages_str.strip() else "" - in_context_section = "" - - if retain_count > 0 and in_context_messages_str.strip(): - in_context_section = f"\n\n(Newer) In-Context Messages:\n{in_context_messages_str}" - elif retain_count == 0: - prompt_header = ( - "You’re a memory-recall helper for an AI that is about to forget all prior messages. " - "Scan the conversation history and write crisp notes that capture any important facts or insights about the conversation history." - ) - - # Compose final prompt - summary_request_text = prompt_header + evicted_section + in_context_section - # Fire-and-forget the summarization task self.fire_and_forget( self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])]) diff --git a/letta/settings.py b/letta/settings.py index b959cdf8..5554951c 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -6,6 +6,7 @@ from pydantic import AliasChoices, Field from pydantic_settings import BaseSettings, SettingsConfigDict from letta.local_llm.constants import DEFAULT_WRAPPER_NAME +from letta.services.summarizer.enums import SummarizationMode class ToolSettings(BaseSettings): @@ -38,6 +39,13 @@ class ToolSettings(BaseSettings): class SummarizerSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="letta_summarizer_", extra="ignore") + mode: SummarizationMode = SummarizationMode.STATIC_MESSAGE_BUFFER + message_buffer_limit: int = 60 + message_buffer_min: int = 15 + 
enable_summarization: bool = True + max_summarization_retries: int = 3 + + # TODO(cliandy): the below settings are tied to old summarization and should be deprecated or moved # Controls if we should evict all messages # TODO: Can refactor this into an enum if we have a bunch of different kinds of summarizers evict_all_messages: bool = False diff --git a/letta/templates/summary_request_text.j2 b/letta/templates/summary_request_text.j2 new file mode 100644 index 00000000..1cf57176 --- /dev/null +++ b/letta/templates/summary_request_text.j2 @@ -0,0 +1,19 @@ +{% if retain_count == 0 %} +You’re a memory-recall helper for an AI that is about to forget all prior messages. Scan the conversation history and write crisp notes that capture any important facts or insights about the conversation history. +{% else %} +You’re a memory-recall helper for an AI that can only keep the last {{ retain_count }} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost. 
+{% endif %} + +{% if evicted_messages %} +(Older) Evicted Messages: +{% for item in evicted_messages %} + {{ item }} +{% endfor %} +{% endif %} + +{% if retain_count > 0 and in_context_messages %} +(Newer) In-Context Messages: +{% for item in in_context_messages %} + {{ item }} +{% endfor %} +{% endif %} diff --git a/tests/configs/llm_model_configs/openai-o1-mini.json b/tests/configs/llm_model_configs/openai-o1-mini.json deleted file mode 100644 index fbfa0c01..00000000 --- a/tests/configs/llm_model_configs/openai-o1-mini.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "context_window": 128000, - "model": "o1-mini", - "model_endpoint_type": "openai", - "model_endpoint": "https://api.openai.com/v1", - "model_wrapper": null, - "temperature": 1.0 -} diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 14fa877f..f84f7d46 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -103,7 +103,6 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [ all_configs = [ "openai-gpt-4o-mini.json", "openai-o1.json", - "openai-o1-mini.json", "openai-o3.json", "openai-o4-mini.json", "azure-gpt-4o-mini.json", @@ -116,7 +115,7 @@ all_configs = [ "gemini-2.5-flash-vertex.json", "gemini-2.5-pro-vertex.json", "together-qwen-2.5-72b-instruct.json", - "ollama.json", + # "ollama.json", # TODO (cliandy): enable this in ollama testing ] @@ -1215,7 +1214,7 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM new_llm_config = llm_config.model_dump() new_llm_config["context_window"] = 3000 pinned_context_window_llm_config = LLMConfig(**new_llm_config) - + print("::LLM::", llm_config, new_llm_config) send_message_tool = client.tools.list(name="send_message")[0] temp_agent_state = client.agents.create( include_base_tools=False, diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index d0831e00..fee463d5 100644 --- 
a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -54,6 +54,7 @@ def actor(server, org_id): server.user_manager.delete_user_by_id(user.id) +@pytest.mark.flaky(max_runs=3) @pytest.mark.asyncio(loop_scope="module") async def test_sleeptime_group_chat(server, actor): # 0. Refresh base tools diff --git a/tests/pytest.ini b/tests/pytest.ini index 35efd715..42dfe970 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -2,6 +2,7 @@ pythonpath = /letta testpaths = /tests asyncio_mode = auto +asyncio_default_fixture_loop_scope = function filterwarnings = ignore::pytest.PytestRemovedIn9Warning # suppresses the warnings we see with the event_loop fixture diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 1e56adc9..46ed45e0 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -443,6 +443,7 @@ def test_function_return_limit(disable_e2b_api_key, client: LettaSDKClient, agen assert "function output was truncated " in res +@pytest.mark.flaky(max_runs=3) def test_function_always_error(client: LettaSDKClient, agent: AgentState): """Test to see if function that errors works correctly""" diff --git a/tests/test_static_buffer_summarize.py b/tests/test_static_buffer_summarize.py index d50c4699..fb6090c3 100644 --- a/tests/test_static_buffer_summarize.py +++ b/tests/test_static_buffer_summarize.py @@ -14,7 +14,6 @@ from letta.services.summarizer.summarizer import Summarizer # Constants for test parameters MESSAGE_BUFFER_LIMIT = 10 MESSAGE_BUFFER_MIN = 3 -PREVIOUS_SUMMARY = "Previous summary" SUMMARY_TEXT = "Summarized memory" @@ -22,6 +21,7 @@ SUMMARY_TEXT = "Summarized memory" def mock_summarizer_agent(): agent = AsyncMock(spec=BaseAgent) agent.step.return_value = [Message(role=MessageRole.assistant, content=[TextContent(type="text", text=SUMMARY_TEXT)])] + agent.update_message_transcript = AsyncMock() return agent @@ -40,10 +40,9 @@ def messages(): @pytest.mark.asyncio async def 
test_static_buffer_summarization_no_trim_needed(mock_summarizer_agent, messages): summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=20) - updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:5], [], PREVIOUS_SUMMARY) + updated_messages, updated = summarizer._static_buffer_summarization(messages[:5], []) assert len(updated_messages) == 5 - assert summary == PREVIOUS_SUMMARY assert not updated @@ -55,11 +54,10 @@ async def test_static_buffer_summarization_trim_needed(mock_summarizer_agent, me message_buffer_limit=MESSAGE_BUFFER_LIMIT, message_buffer_min=MESSAGE_BUFFER_MIN, ) - updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], []) - assert len(updated_messages) == MESSAGE_BUFFER_MIN # Should be trimmed down to min buffer size + assert len(updated_messages) == MESSAGE_BUFFER_MIN assert updated - assert SUMMARY_TEXT in summary mock_summarizer_agent.step.assert_called() @@ -75,21 +73,19 @@ async def test_static_buffer_summarization_trim_user_message(mock_summarizer_age # Modify messages to ensure a user message is available to trim at the correct index messages[5].role = MessageRole.user # Ensure a user message exists in trimming range - updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], []) assert len(updated_messages) == MESSAGE_BUFFER_MIN assert updated - assert SUMMARY_TEXT in summary mock_summarizer_agent.step.assert_called() @pytest.mark.asyncio async def test_static_buffer_summarization_no_trim_no_summarization(mock_summarizer_agent, messages): summarizer = Summarizer(SummarizationMode.STATIC_MESSAGE_BUFFER, mock_summarizer_agent, message_buffer_limit=15) - 
updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:8], [], PREVIOUS_SUMMARY) + updated_messages, updated = summarizer._static_buffer_summarization(messages[:8], []) assert len(updated_messages) == 8 - assert summary == PREVIOUS_SUMMARY assert not updated mock_summarizer_agent.step.assert_not_called() @@ -106,11 +102,10 @@ async def test_static_buffer_summarization_json_parsing_failure(mock_summarizer_ # Inject malformed JSON messages[2].content = [TextContent(type="text", text="malformed json")] - updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], []) assert len(updated_messages) == MESSAGE_BUFFER_MIN assert updated - assert SUMMARY_TEXT in summary mock_summarizer_agent.step.assert_called() @@ -127,11 +122,10 @@ async def test_static_buffer_summarization_all_user_messages_trimmed(mock_summar for i in range(12): messages[i].role = MessageRole.user - updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], []) - assert len(updated_messages) == MESSAGE_BUFFER_MIN + assert len(updated_messages) == MESSAGE_BUFFER_MIN + 1 assert updated - assert SUMMARY_TEXT in summary mock_summarizer_agent.step.assert_called() @@ -148,10 +142,9 @@ async def test_static_buffer_summarization_no_assistant_messages_trimmed(mock_su for i in range(12): messages[i].role = MessageRole.assistant - updated_messages, summary, updated = await summarizer._static_buffer_summarization(messages[:12], [], PREVIOUS_SUMMARY) + updated_messages, updated = summarizer._static_buffer_summarization(messages[:12], []) # Yeah, so this actually has to end on 1, because we basically can find no user, so we trim everything assert len(updated_messages) == 1 assert updated - 
assert SUMMARY_TEXT in summary mock_summarizer_agent.step.assert_called()