diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 05a2a867..e21ac9ac 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -218,6 +218,7 @@ class LettaAgent(BaseAgent): use_assistant_message: bool = True, request_start_timestamp_ns: int | None = None, include_return_message_types: list[MessageType] | None = None, + run_id: str | None = None, ): agent_state = await self.agent_manager.get_agent_by_id_async( agent_id=self.agent_id, @@ -330,6 +331,7 @@ class LettaAgent(BaseAgent): tool_rules_solver, agent_step_span, step_metrics, + run_id=run_id, ) in_context_messages = current_in_context_messages + new_in_context_messages @@ -549,6 +551,7 @@ class LettaAgent(BaseAgent): llm_config=agent_state.llm_config, total_tokens=usage.total_tokens, force=False, + run_id=run_id, ) await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False) @@ -677,6 +680,7 @@ class LettaAgent(BaseAgent): tool_rules_solver, agent_step_span, step_metrics, + run_id=run_id, ) in_context_messages = current_in_context_messages + new_in_context_messages @@ -882,6 +886,7 @@ class LettaAgent(BaseAgent): llm_config=agent_state.llm_config, total_tokens=usage.total_tokens, force=False, + run_id=run_id, ) await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False) @@ -908,6 +913,7 @@ class LettaAgent(BaseAgent): use_assistant_message: bool = True, request_start_timestamp_ns: int | None = None, include_return_message_types: list[MessageType] | None = None, + run_id: str | None = None, ) -> AsyncGenerator[str, None]: """ Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens. @@ -1027,6 +1033,7 @@ class LettaAgent(BaseAgent): agent_state, llm_client, tool_rules_solver, + run_id=run_id, ) step_progression = StepProgression.STREAM_RECEIVED @@ -1378,6 +1385,7 @@ class LettaAgent(BaseAgent): llm_config=agent_state.llm_config, total_tokens=usage.total_tokens, force=False, + run_id=run_id, ) await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False) @@ -1441,6 +1449,7 @@ class LettaAgent(BaseAgent): tool_rules_solver: ToolRulesSolver, agent_step_span: "Span", step_metrics: StepMetrics, + run_id: str | None = None, ) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None: for attempt in range(self.max_summarization_retries + 1): try: @@ -1488,6 +1497,7 @@ class LettaAgent(BaseAgent): new_letta_messages=new_in_context_messages, llm_config=agent_state.llm_config, force=True, + run_id=run_id, ) new_in_context_messages = [] log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}") @@ -1503,6 +1513,7 @@ class LettaAgent(BaseAgent): agent_state: AgentState, llm_client: LLMClientBase, tool_rules_solver: ToolRulesSolver, + run_id: str | None = None, ) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None: for attempt in range(self.max_summarization_retries + 1): try: @@ -1555,6 +1566,7 @@ class LettaAgent(BaseAgent): new_letta_messages=new_in_context_messages, llm_config=agent_state.llm_config, force=True, + run_id=run_id, ) new_in_context_messages: list[Message] = [] log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}") @@ -1568,10 +1580,17 @@ class LettaAgent(BaseAgent): new_letta_messages: list[Message], llm_config: LLMConfig, force: bool, + run_id: str | None = None, + step_id: str | None = None, ) -> list[Message]: if isinstance(e, ContextWindowExceededError): return await self._rebuild_context_window( - in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, llm_config=llm_config, force=force + in_context_messages=in_context_messages, + new_letta_messages=new_letta_messages, + llm_config=llm_config, + force=force, + run_id=run_id, + step_id=step_id, ) else: raise llm_client.handle_llm_error(e) @@ -1584,6 +1603,8 @@ class LettaAgent(BaseAgent): llm_config: LLMConfig, total_tokens: int | None = None, force: bool = False, + run_id: str | None = None, + step_id: str | None = None, ) -> list[Message]: # If total tokens is reached, we truncate down # TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc. @@ -1597,6 +1618,8 @@ class LettaAgent(BaseAgent): new_letta_messages=new_letta_messages, force=True, clear=True, + run_id=run_id, + step_id=step_id, ) else: # NOTE (Sarah): Seems like this is doing nothing? @@ -1606,6 +1629,8 @@ class LettaAgent(BaseAgent): new_in_context_messages, updated = await self.summarizer.summarize( in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, + run_id=run_id, + step_id=step_id, ) await self.agent_manager.update_message_ids_async( agent_id=self.agent_id, diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 14b53440..34569655 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -236,6 +236,7 @@ class LettaAgentV2(BaseAgentV2): new_letta_messages=self.response_messages, total_tokens=self.usage.total_tokens, force=False, + run_id=run_id, ) if self.stop_reason is None: @@ -343,6 +344,7 @@ class LettaAgentV2(BaseAgentV2): new_letta_messages=self.response_messages, total_tokens=self.usage.total_tokens, force=False, + run_id=run_id, ) except: @@ -488,6 +490,8 @@ class LettaAgentV2(BaseAgentV2): in_context_messages=messages, new_letta_messages=self.response_messages, force=True, + run_id=run_id, + step_id=step_id, ) else: raise e @@ -1246,6 +1250,8 @@ class LettaAgentV2(BaseAgentV2): new_letta_messages: list[Message], total_tokens: int | None = None, force: bool = False, + run_id: str | None = None, + step_id: str | None = None, ) -> list[Message]: self.logger.warning("Running deprecated v2 summarizer. This should be removed in the future.") # always skip summarization if last message is an approval request message @@ -1268,6 +1274,8 @@ class LettaAgentV2(BaseAgentV2): new_letta_messages=new_letta_messages, force=True, clear=True, + run_id=run_id, + step_id=step_id, ) else: # NOTE (Sarah): Seems like this is doing nothing? @@ -1277,6 +1285,8 @@ class LettaAgentV2(BaseAgentV2): new_in_context_messages, updated = await self.summarizer.summarize( in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, + run_id=run_id, + step_id=step_id, ) except Exception as e: self.logger.error(f"Failed to summarize conversation history: {e}") diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 6e6cda01..9c454bf2 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -769,7 +769,10 @@ class LettaAgentV3(LettaAgentV2): # TODO: might want to delay this checkpoint in case of corrupated state try: summary_message, messages, _ = await self.compact( - messages, trigger_threshold=self.agent_state.llm_config.context_window + messages, + trigger_threshold=self.agent_state.llm_config.context_window, + run_id=run_id, + step_id=step_id, ) self.logger.info("Summarization succeeded, continuing to retry LLM request") continue @@ -893,7 +896,12 @@ class LettaAgentV3(LettaAgentV2): self.logger.info( f"Context window exceeded (current: {self.context_token_estimate}, threshold: {self.agent_state.llm_config.context_window}), trying to compact messages" ) - summary_message, messages, _ = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window) + summary_message, messages, _ = await self.compact( + messages, + trigger_threshold=self.agent_state.llm_config.context_window, + run_id=run_id, + step_id=step_id, + ) # TODO: persist + return the summary message # TODO: convert this to a SummaryMessage self.response_messages.append(summary_message) @@ -1463,7 +1471,12 @@ class LettaAgentV3(LettaAgentV2): @trace_method async def compact( - self, messages, trigger_threshold: Optional[int] = None, compaction_settings: Optional["CompactionSettings"] = None + self, + messages, + trigger_threshold: Optional[int] = None, + compaction_settings: Optional["CompactionSettings"] = None, + run_id: Optional[str] = None, + step_id: Optional[str] = None, ) -> tuple[Message, list[Message], str]: """Compact the current in-context messages for this agent. @@ -1502,6 +1515,10 @@ class LettaAgentV3(LettaAgentV2): llm_config=summarizer_llm_config, summarizer_config=summarizer_config, in_context_messages=messages, + agent_id=self.agent_state.id, + agent_tags=self.agent_state.tags, + run_id=run_id, + step_id=step_id, ) elif summarizer_config.mode == "sliding_window": try: @@ -1510,6 +1527,10 @@ class LettaAgentV3(LettaAgentV2): llm_config=summarizer_llm_config, summarizer_config=summarizer_config, in_context_messages=messages, + agent_id=self.agent_state.id, + agent_tags=self.agent_state.tags, + run_id=run_id, + step_id=step_id, ) except Exception as e: self.logger.error(f"Sliding window summarization failed with exception: {str(e)}. Falling back to all mode.") @@ -1518,6 +1539,10 @@ class LettaAgentV3(LettaAgentV2): llm_config=summarizer_llm_config, summarizer_config=summarizer_config, in_context_messages=messages, + agent_id=self.agent_state.id, + agent_tags=self.agent_state.tags, + run_id=run_id, + step_id=step_id, ) summarization_mode_used = "all" else: @@ -1551,6 +1576,10 @@ class LettaAgentV3(LettaAgentV2): llm_config=self.agent_state.llm_config, summarizer_config=summarizer_config, in_context_messages=compacted_messages, + agent_id=self.agent_state.id, + agent_tags=self.agent_state.tags, + run_id=run_id, + step_id=step_id, ) summarization_mode_used = "all" diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index d2674033..ed855e84 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -49,6 +49,8 @@ class Summarizer: message_manager: Optional[MessageManager] = None, actor: Optional[User] = None, agent_id: Optional[str] = None, + run_id: Optional[str] = None, + step_id: Optional[str] = None, ): self.mode = mode @@ -64,6 +66,8 @@ class Summarizer: self.message_manager = message_manager self.actor = actor self.agent_id = agent_id + self.run_id = run_id + self.step_id = step_id @trace_method async def summarize( @@ -72,6 +76,8 @@ class Summarizer: new_letta_messages: List[Message], force: bool = False, clear: bool = False, + run_id: Optional[str] = None, + step_id: Optional[str] = None, ) -> Tuple[List[Message], bool]: """ Summarizes or trims in_context_messages according to the chosen mode, @@ -81,6 +87,8 @@ class Summarizer: in_context_messages: The existing messages in the conversation's context. new_letta_messages: The newly added Letta messages (just appended). force: Force summarize even if the criteria is not met + run_id: Optional run ID for telemetry (overrides instance default) + step_id: Optional step ID for telemetry (overrides instance default) Returns: (updated_messages, summary_message) @@ -88,6 +96,9 @@ class Summarizer: summary_message: Optional summarization message that was created (could be appended to the conversation if desired) """ + effective_run_id = run_id if run_id is not None else self.run_id + effective_step_id = step_id if step_id is not None else self.step_id + if self.mode == SummarizationMode.STATIC_MESSAGE_BUFFER: return self._static_buffer_summarization( in_context_messages, @@ -101,6 +112,8 @@ class Summarizer: new_letta_messages, force=force, clear=clear, + run_id=effective_run_id, + step_id=effective_step_id, ) else: # Fallback or future logic @@ -124,6 +137,8 @@ class Summarizer: new_letta_messages: List[Message], force: bool = False, clear: bool = False, + run_id: Optional[str] = None, + step_id: Optional[str] = None, ) -> Tuple[List[Message], bool]: """Summarization as implemented in the original MemGPT loop, but using message count instead of token count. Evict a partial amount of messages, and replace message[1] with a recursive summary. @@ -173,6 +188,8 @@ class Summarizer: include_ack=True, agent_id=self.agent_id, agent_tags=agent_state.tags, + run_id=run_id if run_id is not None else self.run_id, + step_id=step_id if step_id is not None else self.step_id, ) # TODO add counts back @@ -432,6 +449,7 @@ async def simple_summary( agent_id: str | None = None, agent_tags: List[str] | None = None, run_id: str | None = None, + step_id: str | None = None, ) -> str: """Generate a simple summary from a list of messages. @@ -454,6 +472,7 @@ async def simple_summary( agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, + step_id=step_id, call_type="summarization", ) diff --git a/letta/services/summarizer/summarizer_all.py b/letta/services/summarizer/summarizer_all.py index 7918f89f..3d9b2ffa 100644 --- a/letta/services/summarizer/summarizer_all.py +++ b/letta/services/summarizer/summarizer_all.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from letta.log import get_logger from letta.otel.tracing import trace_method @@ -20,7 +20,11 @@ async def summarize_all( # Actual summarization configuration summarizer_config: CompactionSettings, in_context_messages: List[Message], - # new_messages: List[Message], + # Telemetry context + agent_id: Optional[str] = None, + agent_tags: Optional[List[str]] = None, + run_id: Optional[str] = None, + step_id: Optional[str] = None, ) -> str: """ Summarize the entire conversation history into a single summary. @@ -60,6 +64,10 @@ async def summarize_all( actor=actor, include_ack=bool(summarizer_config.prompt_acknowledgement), prompt=summarizer_config.prompt, + agent_id=agent_id, + agent_tags=agent_tags, + run_id=run_id, + step_id=step_id, ) logger.info(f"Summarized {len(messages_to_summarize)} messages") diff --git a/letta/services/summarizer/summarizer_sliding_window.py b/letta/services/summarizer/summarizer_sliding_window.py index d38ca58e..10a409d2 100644 --- a/letta/services/summarizer/summarizer_sliding_window.py +++ b/letta/services/summarizer/summarizer_sliding_window.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple from letta.helpers.message_helper import convert_message_creates_to_messages from letta.log import get_logger @@ -50,7 +50,11 @@ async def summarize_via_sliding_window( llm_config: LLMConfig, summarizer_config: CompactionSettings, in_context_messages: List[Message], - # new_messages: List[Message], + # Telemetry context + agent_id: Optional[str] = None, + agent_tags: Optional[List[str]] = None, + run_id: Optional[str] = None, + step_id: Optional[str] = None, ) -> Tuple[str, List[Message]]: """ If the total tokens is greater than the context window limit (or force=True), @@ -138,6 +142,10 @@ async def summarize_via_sliding_window( actor=actor, include_ack=bool(summarizer_config.prompt_acknowledgement), prompt=summarizer_config.prompt, + agent_id=agent_id, + agent_tags=agent_tags, + run_id=run_id, + step_id=step_id, ) if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars: diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 39998108..fec63a4d 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -816,7 +816,7 @@ async def test_v3_compact_uses_compaction_settings_model_and_model_settings(serv captured_llm_config: dict = {} - async def fake_simple_summary(messages, llm_config, actor, include_ack=True, prompt=None): # type: ignore[override] + async def fake_simple_summary(messages, llm_config, actor, include_ack=True, prompt=None, **kwargs): # type: ignore[override] captured_llm_config["value"] = llm_config return "summary text"