diff --git a/letta/agent.py b/letta/agent.py
index b276ed79..41fdf6dd 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -1269,6 +1269,119 @@ class Agent(BaseAgent):
             functions_definitions=available_functions_definitions,
         )
 
+    async def get_context_window_async(self) -> ContextWindowOverview:
+        """Get the context window of the agent"""
+        # Grab the in-context messages
+        # conversion of messages to OpenAI dict format, which is passed to the token counter
+        in_context_messages = await self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user)
+        in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
+
+        # Extract system, memory and external summary
+        if (
+            len(in_context_messages) > 0
+            and in_context_messages[0].role == MessageRole.system
+            and in_context_messages[0].content
+            and len(in_context_messages[0].content) == 1
+            and isinstance(in_context_messages[0].content[0], TextContent)
+        ):
+            system_message = in_context_messages[0].content[0].text
+
+            external_memory_marker_pos = system_message.find("###")
+            core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
+            if external_memory_marker_pos != -1 and core_memory_marker_pos != -1:
+                system_prompt = system_message[:external_memory_marker_pos].strip()
+                external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip()
+                core_memory = system_message[core_memory_marker_pos:].strip()
+            else:
+                # if no markers found, put everything in system message
+                system_prompt = system_message
+                external_memory_summary = ""
+                core_memory = ""
+        else:
+            # if no system message, fall back on agent's system prompt
+            system_prompt = self.agent_state.system
+            external_memory_summary = ""
+            core_memory = ""
+
+        num_tokens_system = count_tokens(system_prompt)
+        num_tokens_core_memory = count_tokens(core_memory)
+        num_tokens_external_memory_summary = count_tokens(external_memory_summary)
+
+        # Check if there's a summary message in the message queue
+        if (
+            len(in_context_messages) > 1
+            and in_context_messages[1].role == MessageRole.user
+            and in_context_messages[1].content
+            and len(in_context_messages[1].content) == 1
+            and isinstance(in_context_messages[1].content[0], TextContent)
+            # TODO remove hardcoding
+            and "The following is a summary of the previous " in in_context_messages[1].content[0].text
+        ):
+            # Summary message exists
+            text_content = in_context_messages[1].content[0].text
+            assert text_content is not None
+            summary_memory = text_content
+            num_tokens_summary_memory = count_tokens(text_content)
+            # with a summary message, the real messages start at index 2
+            num_tokens_messages = (
+                num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model)
+                if len(in_context_messages_openai) > 2
+                else 0
+            )
+
+        else:
+            summary_memory = None
+            num_tokens_summary_memory = 0
+            # with no summary message, the real messages start at index 1
+            num_tokens_messages = (
+                num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model)
+                if len(in_context_messages_openai) > 1
+                else 0
+            )
+
+        # tokens taken up by function definitions
+        agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
+        if agent_state_tool_jsons:
+            available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons]
+            num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model)
+        else:
+            available_functions_definitions = []
+            num_tokens_available_functions_definitions = 0
+
+        num_tokens_used_total = (
+            num_tokens_system  # system prompt
+            + num_tokens_available_functions_definitions  # function definitions
+            + num_tokens_core_memory  # core memory
+            + num_tokens_external_memory_summary  # metadata (statistics) about recall/archival
+            + num_tokens_summary_memory  # summary of ongoing conversation
+            + num_tokens_messages  # tokens taken by messages
+        )
+        assert isinstance(num_tokens_used_total, int)
+
+        return ContextWindowOverview(
+            # context window breakdown (in messages)
+            num_messages=len(in_context_messages),
+            num_archival_memory=agent_manager_passage_size,
+            num_recall_memory=message_manager_size,
+            num_tokens_external_memory_summary=num_tokens_external_memory_summary,
+            external_memory_summary=external_memory_summary,
+            # top-level information
+            context_window_size_max=self.agent_state.llm_config.context_window,
+            context_window_size_current=num_tokens_used_total,
+            # context window breakdown (in tokens)
+            num_tokens_system=num_tokens_system,
+            system_prompt=system_prompt,
+            num_tokens_core_memory=num_tokens_core_memory,
+            core_memory=core_memory,
+            num_tokens_summary_memory=num_tokens_summary_memory,
+            summary_memory=summary_memory,
+            num_tokens_messages=num_tokens_messages,
+            messages=in_context_messages,
+            # related to functions
+            num_tokens_functions_definitions=num_tokens_available_functions_definitions,
+            functions_definitions=available_functions_definitions,
+        )
+
     def count_tokens(self) -> int:
         """Count the tokens in the current context window"""
         context_window_breakdown = self.get_context_window()
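
A minimal usage sketch for the new coroutine, assuming an already-constructed letta `Agent` instance named `agent` (construction and application wiring are not part of this patch). It awaits `get_context_window_async` and prints the token breakdown from the returned `ContextWindowOverview`, using only fields populated in the return statement above.

# Hypothetical driver; `agent` is assumed to be a constructed letta Agent.
async def report_context_usage(agent) -> None:
    overview = await agent.get_context_window_async()
    # Overall usage versus the model's maximum context window
    print(f"context window: {overview.context_window_size_current}/{overview.context_window_size_max} tokens")
    # Per-component token counts computed in get_context_window_async
    print(f"  system prompt:        {overview.num_tokens_system}")
    print(f"  core memory:          {overview.num_tokens_core_memory}")
    print(f"  summary memory:       {overview.num_tokens_summary_memory}")
    print(f"  function definitions: {overview.num_tokens_functions_definitions}")
    print(f"  messages:             {overview.num_tokens_messages}")

# e.g. asyncio.run(report_context_usage(agent)) from synchronous code,
# or `await report_context_usage(agent)` inside an existing event loop.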