From e5e4ed5111df774e2931ca33a36e2210681e09c3 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Thu, 22 May 2025 20:30:41 -0700 Subject: [PATCH] chore: move context window estimate to `agent_manager` for full async (#2354) --- letta/server/rest_api/routers/v1/agents.py | 2 +- letta/server/server.py | 8 - letta/services/agent_manager.py | 285 ++++++++++++++++++++- tests/test_server.py | 5 +- 4 files changed, 287 insertions(+), 13 deletions(-) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index eaa14cfd..c042de03 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -212,7 +212,7 @@ async def retrieve_agent_context_window( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: - return await server.get_agent_context_window_async(agent_id=agent_id, actor=actor) + return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor) except Exception as e: traceback.print_exc() raise e diff --git a/letta/server/server.py b/letta/server/server.py index e78aacc1..be3e33da 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1618,14 +1618,6 @@ class SyncServer(Server): def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig: """Add a new embedding model""" - def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowOverview: - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - return letta_agent.get_context_window() - - async def get_agent_context_window_async(self, agent_id: str, actor: User) -> ContextWindowOverview: - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - return await letta_agent.get_context_window_async() - def run_tool_from_source( self, actor: User, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 6ceefb15..2abeb95b 100644 --- a/letta/services/agent_manager.py +++ 
b/letta/services/agent_manager.py @@ -1,9 +1,11 @@ import asyncio +import os from datetime import datetime, timezone from typing import Dict, List, Optional, Set, Tuple import numpy as np import sqlalchemy as sa +from openai.types.beta.function_tool import FunctionTool as OpenAITool from sqlalchemy import Select, and_, delete, func, insert, literal, or_, select, union_all from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -20,6 +22,7 @@ from letta.constants import ( ) from letta.embeddings import embedding_model from letta.helpers.datetime_helpers import get_utc_time +from letta.llm_api.llm_client import LLMClient from letta.log import get_logger from letta.orm import Agent as AgentModel from letta.orm import AgentPassage, AgentsTags @@ -42,9 +45,11 @@ from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent, get_prompt_ from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import MessageRole, ProviderType from letta.schemas.group import Group as PydanticGroup from letta.schemas.group import ManagerType -from letta.schemas.memory import Memory +from letta.schemas.letta_message_content import TextContent +from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate @@ -79,7 +84,7 @@ from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings from letta.tracing import trace_method -from letta.utils import enforce_types, united_diff +from letta.utils import count_tokens, enforce_types, united_diff logger = get_logger(__name__) @@ -2332,3 +2337,279 @@ class AgentManager: # Extract the tag values from the result results = [row[0] for row in result.all()] return results 
+
+    async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
+        """Estimate how the agent's context window is currently being used.
+
+        The Anthropic token-counting client is used when the ``LETTA_ENVIRONMENT``
+        environment variable equals ``"PRODUCTION"``; otherwise tokens are counted
+        locally via `get_context_window_from_tiktoken_async`.
+
+        Args:
+            agent_id: ID of the agent to inspect.
+            actor: User on whose behalf the agent and its messages are loaded.
+
+        Returns:
+            The overview produced by the selected implementation.
+        """
+        if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION":
+            return await self.get_context_window_from_anthropic_async(agent_id=agent_id, actor=actor)
+        return await self.get_context_window_from_tiktoken_async(agent_id=agent_id, actor=actor)
+
+    @staticmethod
+    def _is_single_text_message(message: PydanticMessage, role: MessageRole) -> bool:
+        """Return True when `message` has `role` and consists of exactly one `TextContent` part."""
+        return bool(
+            message.role == role
+            and message.content
+            and len(message.content) == 1
+            and isinstance(message.content[0], TextContent)
+        )
+
+    @staticmethod
+    def _split_system_message(system_message: str) -> Tuple[str, str, str]:
+        """Split a compiled system message into its three sections.
+
+        The external memory summary starts at the first ``###`` and the core memory
+        at the first ``<`` that follows it.
+
+        Returns:
+            ``(system_prompt, external_memory_summary, core_memory)``. When either
+            marker is missing, the whole message is returned as the system prompt
+            and the other two sections are empty strings.
+        """
+        external_memory_marker_pos = system_message.find("###")
+        if external_memory_marker_pos == -1:
+            return system_message, "", ""
+        core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
+        if core_memory_marker_pos == -1:
+            return system_message, "", ""
+        return (
+            system_message[:external_memory_marker_pos].strip(),
+            system_message[external_memory_marker_pos:core_memory_marker_pos].strip(),
+            system_message[core_memory_marker_pos:].strip(),
+        )
+
+    @classmethod
+    def _extract_system_components(
+        cls, in_context_messages: List[PydanticMessage], fallback_system_prompt: str
+    ) -> Tuple[str, str, str]:
+        """Return ``(system_prompt, external_memory_summary, core_memory)`` for the in-context messages.
+
+        The sections are parsed from the first message when it is a single-text
+        system message; otherwise `fallback_system_prompt` is used as the system
+        prompt and the other two sections are empty strings.
+        """
+        if in_context_messages and cls._is_single_text_message(in_context_messages[0], MessageRole.system):
+            return cls._split_system_message(in_context_messages[0].content[0].text)
+        return fallback_system_prompt, "", ""
+
+    @classmethod
+    def _extract_summary_memory(cls, in_context_messages: List[PydanticMessage]) -> Optional[str]:
+        """Return the text of the conversation-summary message, or None when there is none.
+
+        A summary is recognised only at index 1, as a single-text user message whose
+        text contains the summary preamble.
+        """
+        if len(in_context_messages) > 1 and cls._is_single_text_message(in_context_messages[1], MessageRole.user):
+            text_content = in_context_messages[1].content[0].text
+            # TODO remove hardcoding
+            if "The following is a summary of the previous " in text_content:
+                return text_content
+        return None
+
+    @staticmethod
+    def _build_context_window_overview(
+        *,
+        agent_state,
+        in_context_messages: List[PydanticMessage],
+        num_archival_memory,
+        num_recall_memory,
+        system_prompt: str,
+        num_tokens_system: int,
+        core_memory: str,
+        num_tokens_core_memory: int,
+        external_memory_summary: str,
+        num_tokens_external_memory_summary: int,
+        summary_memory: Optional[str],
+        num_tokens_summary_memory: int,
+        num_tokens_messages: int,
+        functions_definitions: List[OpenAITool],
+        num_tokens_functions_definitions: int,
+    ) -> ContextWindowOverview:
+        """Assemble the overview from section contents and their token counts.
+
+        ``context_window_size_current`` is the sum of all section token counts and
+        ``context_window_size_max`` is read from ``agent_state.llm_config.context_window``.
+        """
+        num_tokens_used_total = (
+            num_tokens_system  # system prompt
+            + num_tokens_functions_definitions  # function definitions
+            + num_tokens_core_memory  # core memory
+            + num_tokens_external_memory_summary  # metadata (statistics) about recall/archival
+            + num_tokens_summary_memory  # summary of ongoing conversation
+            + num_tokens_messages  # tokens taken by messages
+        )
+        assert isinstance(num_tokens_used_total, int)
+
+        return ContextWindowOverview(
+            # context window breakdown (in messages)
+            num_messages=len(in_context_messages),
+            num_archival_memory=num_archival_memory,
+            num_recall_memory=num_recall_memory,
+            num_tokens_external_memory_summary=num_tokens_external_memory_summary,
+            external_memory_summary=external_memory_summary,
+            # top-level information
+            context_window_size_max=agent_state.llm_config.context_window,
+            context_window_size_current=num_tokens_used_total,
+            # context window breakdown (in tokens)
+            num_tokens_system=num_tokens_system,
+            system_prompt=system_prompt,
+            num_tokens_core_memory=num_tokens_core_memory,
+            core_memory=core_memory,
+            num_tokens_summary_memory=num_tokens_summary_memory,
+            summary_memory=summary_memory,
+            num_tokens_messages=num_tokens_messages,
+            messages=in_context_messages,
+            # related to functions
+            num_tokens_functions_definitions=num_tokens_functions_definitions,
+            functions_definitions=functions_definitions,
+        )
+
+    async def get_context_window_from_anthropic_async(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
+        """Get the context window of the agent, counting tokens with the Anthropic client.
+
+        Absent sections are reported as empty strings (matching the tiktoken variant)
+        with a token count of 0; ``summary_memory`` is None when there is no summary.
+        """
+        agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
+        anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
+        # only pass the agent's model name through when its endpoint type is anthropic
+        model = agent_state.llm_config.model if agent_state.llm_config.model_endpoint_type == "anthropic" else None
+
+        # Grab the in-context messages
+        # conversion of messages to anthropic dict format, which is passed to the token counter
+        (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
+            self.get_in_context_messages_async(agent_id=agent_id, actor=actor),
+            self.passage_manager.size_async(actor=actor, agent_id=agent_id),
+            self.message_manager.size_async(actor=actor, agent_id=agent_id),
+        )
+        in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]
+
+        # Extract system, memory and external summary
+        system_prompt, external_memory_summary, core_memory = self._extract_system_components(
+            in_context_messages, agent_state.system
+        )
+
+        # with a summary message the real messages start at index 2, otherwise at index 1
+        summary_memory = self._extract_summary_memory(in_context_messages)
+        conversation = in_context_messages_anthropic[1 if summary_memory is None else 2 :]
+
+        # Build the tool schemas before any counting coroutine exists, so that a failure here
+        # cannot leave already-created coroutines un-awaited.
+        available_functions_definitions = [
+            OpenAITool(type="function", function=f.json_schema) for f in (agent_state.tools or [])
+        ]
+
+        def count_text(text: Optional[str]):
+            """Return an awaitable token count for `text` sent as one user message (0 when empty)."""
+            if not text:
+                return asyncio.sleep(0, result=0)
+            return anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": text}])
+
+        (
+            num_tokens_system,
+            num_tokens_core_memory,
+            num_tokens_external_memory_summary,
+            num_tokens_summary_memory,
+            num_tokens_messages,
+            num_tokens_available_functions_definitions,
+        ) = await asyncio.gather(
+            # the system prompt is always sent to the counter, even when it is empty
+            anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}]),
+            count_text(core_memory),
+            count_text(external_memory_summary),
+            count_text(summary_memory),
+            (
+                anthropic_client.count_tokens(model=model, messages=conversation)
+                if conversation
+                else asyncio.sleep(0, result=0)
+            ),
+            (
+                anthropic_client.count_tokens(model=model, tools=available_functions_definitions)
+                if available_functions_definitions
+                else asyncio.sleep(0, result=0)
+            ),
+        )
+
+        return self._build_context_window_overview(
+            agent_state=agent_state,
+            in_context_messages=in_context_messages,
+            num_archival_memory=passage_manager_size,
+            num_recall_memory=message_manager_size,
+            system_prompt=system_prompt,
+            num_tokens_system=num_tokens_system,
+            core_memory=core_memory,
+            num_tokens_core_memory=num_tokens_core_memory,
+            external_memory_summary=external_memory_summary,
+            num_tokens_external_memory_summary=num_tokens_external_memory_summary,
+            summary_memory=summary_memory,
+            num_tokens_summary_memory=num_tokens_summary_memory,
+            num_tokens_messages=num_tokens_messages,
+            functions_definitions=available_functions_definitions,
+            num_tokens_functions_definitions=num_tokens_available_functions_definitions,
+        )
+
+    async def get_context_window_from_tiktoken_async(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
+        """Get the context window of the agent, counting tokens locally.
+
+        Text sections are counted with `count_tokens`; messages and tool schemas
+        with the helpers from `letta.local_llm.utils`.
+        """
+        from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
+
+        agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
+        # Grab the in-context messages
+        # conversion of messages to OpenAI dict format, which is passed to the token counter
+        (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
+            self.get_in_context_messages_async(agent_id=agent_id, actor=actor),
+            self.passage_manager.size_async(actor=actor, agent_id=agent_id),
+            self.message_manager.size_async(actor=actor, agent_id=agent_id),
+        )
+        in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
+
+        # Extract system, memory and external summary
+        system_prompt, external_memory_summary, core_memory = self._extract_system_components(
+            in_context_messages, agent_state.system
+        )
+
+        # with a summary message the real messages start at index 2, otherwise at index 1
+        summary_memory = self._extract_summary_memory(in_context_messages)
+        conversation = in_context_messages_openai[1 if summary_memory is None else 2 :]
+
+        model = agent_state.llm_config.model
+        num_tokens_messages = num_tokens_from_messages(messages=conversation, model=model) if conversation else 0
+
+        # tokens taken up by function definitions
+        agent_state_tool_jsons = [t.json_schema for t in (agent_state.tools or [])]
+        num_tokens_functions = (
+            num_tokens_from_functions(functions=agent_state_tool_jsons, model=model) if agent_state_tool_jsons else 0
+        )
+
+        return self._build_context_window_overview(
+            agent_state=agent_state,
+            in_context_messages=in_context_messages,
+            num_archival_memory=passage_manager_size,
+            num_recall_memory=message_manager_size,
+            system_prompt=system_prompt,
+            num_tokens_system=count_tokens(system_prompt),
+            core_memory=core_memory,
+            num_tokens_core_memory=count_tokens(core_memory),
+            external_memory_summary=external_memory_summary,
+            num_tokens_external_memory_summary=count_tokens(external_memory_summary),
+            summary_memory=summary_memory,
+            num_tokens_summary_memory=count_tokens(summary_memory) if summary_memory is not None else 0,
+            num_tokens_messages=num_tokens_messages,
+            functions_definitions=[OpenAITool(type="function", function=f) for f in agent_state_tool_jsons],
+            num_tokens_functions_definitions=num_tokens_functions,
+        )
diff --git a/tests/test_server.py b/tests/test_server.py
index b798d33c..519d95d6 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -467,9 +467,10 @@ def test_get_recall_memory(server, org_id, user, agent_id):
     # assert len(passage_none) == 0
 
 
-def test_get_context_window_overview(server: SyncServer, user, agent_id):
+@pytest.mark.asyncio
+async def test_get_context_window_overview(server: SyncServer, user, agent_id):
     """Test that the context window overview fetch works"""
-    overview = server.get_agent_context_window(agent_id=agent_id, actor=user)
+    overview = await server.agent_manager.get_context_window(agent_id=agent_id, actor=user)
     assert overview is not None
 
     # Run some basic checks