From 157634bf4403dcde96c17b8be33caaf5f92aa285 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Tue, 15 Oct 2024 13:32:37 -0700
Subject: [PATCH] feat: add `GET` route to get the breakdown of an agent's context window (#1889)

---
 letta/agent.py                             | 68 +++++++++++++++++++++-
 letta/schemas/memory.py                    | 37 ++++++++++++
 letta/server/rest_api/routers/v1/agents.py | 15 +++++
 letta/server/server.py                     | 17 +++++-
 4 files changed, 135 insertions(+), 2 deletions(-)

diff --git a/letta/agent.py b/letta/agent.py
index 76ba7378..6186a6c7 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -23,6 +23,7 @@ from letta.errors import LLMError
 from letta.interface import AgentInterface
 from letta.llm_api.helpers import is_context_overflow_error
 from letta.llm_api.llm_api_tools import create
+from letta.local_llm.utils import num_tokens_from_messages
 from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
 from letta.metadata import MetadataStore
 from letta.persistence_manager import LocalStateManager
@@ -30,7 +31,7 @@ from letta.schemas.agent import AgentState, AgentStepResponse
 from letta.schemas.block import Block
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole, OptionState
-from letta.schemas.memory import Memory
+from letta.schemas.memory import ContextWindowOverview, Memory
 from letta.schemas.message import Message, UpdateMessage
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.schemas.openai.chat_completion_response import (
@@ -1421,6 +1422,71 @@ class Agent(BaseAgent):
         assert all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
         return messages
 
+    def get_context_window(self) -> ContextWindowOverview:
+        """Return a breakdown of the agent's context window (message counts, token counts, and contents)"""
+
+        system_prompt = self.agent_state.system  # TODO is this the current system or the initial system?
+        num_tokens_system = count_tokens(system_prompt)
+        core_memory = self.memory.compile()
+        num_tokens_core_memory = count_tokens(core_memory)
+
+        # conversion of messages to OpenAI dict format, which is passed to the token counter
+        messages_openai_format = self.messages
+
+        # Check if there's a summary message in the message queue
+        if (
+            len(self._messages) > 1
+            and self._messages[1].role == MessageRole.user
+            and isinstance(self._messages[1].text, str)
+            # TODO remove hardcoding
+            and "The following is a summary of the previous " in self._messages[1].text
+        ):
+            # Summary message exists
+            assert self._messages[1].text is not None
+            summary_memory = self._messages[1].text
+            num_tokens_summary_memory = count_tokens(self._messages[1].text)
+            # with a summary message, the real messages start at index 2
+            num_tokens_messages = (
+                num_tokens_from_messages(messages=messages_openai_format[2:], model=self.model) if len(messages_openai_format) > 2 else 0
+            )
+
+        else:
+            summary_memory = None
+            num_tokens_summary_memory = 0
+            # with no summary message, the real messages start at index 1
+            num_tokens_messages = (
+                num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
+            )
+
+        num_archival_memory = self.persistence_manager.archival_memory.storage.size()
+        num_recall_memory = self.persistence_manager.recall_memory.storage.size()
+        external_memory_summary = compile_memory_metadata_block(
+            memory_edit_timestamp=get_utc_time(),  # dummy timestamp
+            archival_memory=self.persistence_manager.archival_memory,
+            recall_memory=self.persistence_manager.recall_memory,
+        )
+        num_tokens_external_memory_summary = count_tokens(external_memory_summary)
+
+        return ContextWindowOverview(
+            # context window breakdown (in messages)
+            num_messages=len(self._messages),
+            num_archival_memory=num_archival_memory,
+            num_recall_memory=num_recall_memory,
+            num_tokens_external_memory_summary=num_tokens_external_memory_summary,
+            # top-level information
+            context_window_size_max=self.agent_state.llm_config.context_window,
+            context_window_size_current=num_tokens_system + num_tokens_core_memory + num_tokens_summary_memory + num_tokens_messages,
+            # context window breakdown (in tokens)
+            num_tokens_system=num_tokens_system,
+            system_prompt=system_prompt,
+            num_tokens_core_memory=num_tokens_core_memory,
+            core_memory=core_memory,
+            num_tokens_summary_memory=num_tokens_summary_memory,
+            summary_memory=summary_memory,
+            num_tokens_messages=num_tokens_messages,
+            messages=self._messages,
+        )
+
 
 def save_agent(agent: Agent, ms: MetadataStore):
     """Save agent to metadata store"""
diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py
index 42ce6ec1..91d52cc3 100644
--- a/letta/schemas/memory.py
+++ b/letta/schemas/memory.py
@@ -8,6 +8,43 @@ if TYPE_CHECKING:
     from letta.agent import Agent
 
 from letta.schemas.block import Block
+from letta.schemas.message import Message
+
+
+class ContextWindowOverview(BaseModel):
+    """
+    Overview of the context window, including the number of messages and tokens.
+    """
+
+    # top-level information
+    context_window_size_max: int = Field(..., description="The maximum amount of tokens the context window can hold.")
+    context_window_size_current: int = Field(..., description="The current number of tokens in the context window.")
+
+    # context window breakdown (in messages)
+    # (technically not in the context window, but useful to know)
+    num_messages: int = Field(..., description="The number of messages in the context window.")
+    num_archival_memory: int = Field(..., description="The number of messages in the archival memory.")
+    num_recall_memory: int = Field(..., description="The number of messages in the recall memory.")
+    num_tokens_external_memory_summary: int = Field(
+        ..., description="The number of tokens in the external memory summary (archival + recall metadata)."
+    )
+
+    # context window breakdown (in tokens)
+    # this should all add up to context_window_size_current
+
+    num_tokens_system: int = Field(..., description="The number of tokens in the system prompt.")
+    system_prompt: str = Field(..., description="The content of the system prompt.")
+
+    num_tokens_core_memory: int = Field(..., description="The number of tokens in the core memory.")
+    core_memory: str = Field(..., description="The content of the core memory.")
+
+    num_tokens_summary_memory: int = Field(..., description="The number of tokens in the summary memory.")
+    summary_memory: Optional[str] = Field(None, description="The content of the summary memory.")
+
+    num_tokens_messages: int = Field(..., description="The number of tokens in the messages list.")
+    # TODO make list of messages?
+    # messages: List[dict] = Field(..., description="The messages in the context window.")
+    messages: List[Message] = Field(..., description="The messages in the context window.")
 
 
 class Memory(BaseModel, validate_assignment=True):
diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index 6d42c2fb..b7a2c9bf 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -19,6 +19,7 @@ from letta.schemas.letta_response import LettaResponse
 from letta.schemas.memory import (
     ArchivalMemorySummary,
     BasicBlockMemory,
+    ContextWindowOverview,
     CreateArchivalMemory,
     Memory,
     RecallMemorySummary,
@@ -51,6 +52,20 @@ def list_agents(
     return server.list_agents(user_id=actor.id)
 
 
+@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="get_agent_context_window")
+def get_agent_context_window(
+    agent_id: str,
+    server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+):
+    """
+    Retrieve the context window of a specific agent.
+    """
+    actor = server.get_user_or_default(user_id=user_id)
+
+    return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
+
+
 @router.post("/", response_model=AgentState, operation_id="create_agent")
 def create_agent(
     agent: CreateAgent = Body(...),
diff --git a/letta/server/server.py b/letta/server/server.py
index 16df2be9..900c5217 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -72,7 +72,12 @@ from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.letta_message import LettaMessage
 from letta.schemas.llm_config import LLMConfig
-from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
+from letta.schemas.memory import (
+    ArchivalMemorySummary,
+    ContextWindowOverview,
+    Memory,
+    RecallMemorySummary,
+)
 from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
 from letta.schemas.organization import Organization, OrganizationCreate
 from letta.schemas.passage import Passage
@@ -2019,3 +2024,13 @@ class SyncServer(Server):
 
     def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
         """Add a new embedding model"""
+
+    def get_agent_context_window(
+        self,
+        user_id: str,
+        agent_id: str,
+    ) -> ContextWindowOverview:
+        """Return the context window overview of the agent identified by `agent_id`"""
+        # NOTE: user_id is not referenced below; the agent is looked up by agent_id alone
+        letta_agent = self._get_or_load_agent(agent_id=agent_id)
+        return letta_agent.get_context_window()