diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index a81d67fa..a749f279 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -452,12 +452,15 @@ async def retrieve_agent_context_window( agent_id: AgentId, server: "SyncServer" = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), + conversation_id: Optional[str] = Query( + None, description="Conversation ID to get context window for. If provided, uses messages from this conversation." + ), ): """ Retrieve the context window of a specific agent. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor) + return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor, conversation_id=conversation_id) class CreateAgentRequest(CreateAgent): @@ -2216,7 +2219,7 @@ async def capture_messages( messages_to_persist.append( Message( role=MessageRole.user, - content=[(TextContent(text=message["content"]))], + content=[TextContent(text=message["content"])], agent_id=agent_id, tool_calls=None, tool_call_id=None, @@ -2228,7 +2231,7 @@ async def capture_messages( messages_to_persist.append( Message( role=MessageRole.assistant, - content=[(TextContent(text=request.response_dict["content"]))], + content=[TextContent(text=request.response_dict["content"])], agent_id=agent_id, model=request.model, tool_calls=None, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index da2f5913..f7b3560b 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -83,6 +83,7 @@ from letta.services.archive_manager import ArchiveManager from letta.services.block_manager import BlockManager, validate_block_limit_constraint from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator from letta.services.context_window_calculator.token_counter import create_token_counter +from letta.services.conversation_manager import ConversationManager from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.files_agents_manager import FileAgentManager from letta.services.helpers.agent_manager_helper import ( @@ -137,6 +138,7 @@ class AgentManager: self.identity_manager = IdentityManager() self.file_agent_manager = FileAgentManager() self.archive_manager = ArchiveManager() + self.conversation_manager = ConversationManager() @staticmethod def _should_exclude_model_from_base_tool_rules(model: str) -> bool: @@ -3388,7 +3390,7 @@ class AgentManager: @enforce_types @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) @trace_method - async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: + async def get_context_window(self, agent_id: str, actor: PydanticUser, conversation_id: Optional[str] = None) -> ContextWindowOverview: agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async( agent_id=agent_id, actor=actor, force=True, dry_run=True ) @@ -3402,6 +3404,16 @@ class AgentManager: agent_id=agent_id, ) + # If conversation_id is provided, get message_ids from the conversation + # Skip the first message ID (system message) since it's passed separately + message_ids = None + if conversation_id is not None: + conversation_message_ids = await self.conversation_manager.get_message_ids_for_conversation( + conversation_id=conversation_id, actor=actor + ) + # Skip the system message (first message) as it's handled separately + message_ids = conversation_message_ids[1:] if conversation_message_ids else [] + try: result = await calculator.calculate_context_window( agent_state=agent_state, @@ -3411,6 +3423,7 @@ class AgentManager: system_message_compiled=system_message, num_archival_memories=num_archival_memories, num_messages=num_messages, + message_ids=message_ids, ) except Exception as e: raise e diff --git a/letta/services/context_window_calculator/context_window_calculator.py b/letta/services/context_window_calculator/context_window_calculator.py index bc51659e..80fe7659 100644 --- a/letta/services/context_window_calculator/context_window_calculator.py +++ b/letta/services/context_window_calculator/context_window_calculator.py @@ -105,9 +105,17 @@ class ContextWindowCalculator: system_message_compiled: Message, num_archival_memories: int, num_messages: int, + message_ids: Optional[List[str]] = None, ) -> ContextWindowOverview: - """Calculate context window information using the provided token counter""" - messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids[1:], actor=actor) + """Calculate context window information using the provided token counter + + Args: + message_ids: Optional list of message IDs to use instead of agent_state.message_ids. + If provided, should NOT include the system message ID (index 0). + """ + # Use provided message_ids or fall back to agent_state.message_ids[1:] + effective_message_ids = message_ids if message_ids is not None else agent_state.message_ids[1:] + messages = await message_manager.get_messages_by_ids_async(message_ids=effective_message_ids, actor=actor) in_context_messages = [system_message_compiled] + messages # Filter out None messages (can occur when system message is missing)