feat: add GET route to get the breakdown of an agent's context window (#1889)
This commit is contained in:
@@ -23,6 +23,7 @@ from letta.errors import LLMError
|
||||
from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.local_llm.utils import num_tokens_from_messages
|
||||
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
|
||||
from letta.metadata import MetadataStore
|
||||
from letta.persistence_manager import LocalStateManager
|
||||
@@ -30,7 +31,7 @@ from letta.schemas.agent import AgentState, AgentStepResponse
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, OptionState
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
from letta.schemas.message import Message, UpdateMessage
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
@@ -1421,6 +1422,71 @@ class Agent(BaseAgent):
|
||||
assert all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
|
||||
return messages
|
||||
|
||||
def get_context_window(self) -> ContextWindowOverview:
    """Return a per-component breakdown of the agent's context window.

    Counts both messages (in-context queue, archival, recall) and tokens
    (system prompt, core memory, external-memory summary block, optional
    summarizer message, and the remaining message history).

    Returns:
        ContextWindowOverview: message counts, per-component token counts,
        and the max/current context-window sizes. The per-component token
        fields sum to ``context_window_size_current``.
    """
    # System prompt tokens.
    system_prompt = self.agent_state.system  # TODO is this the current system or the initial system?
    num_tokens_system = count_tokens(system_prompt)

    # Core (in-context) memory, counted on its compiled text form.
    core_memory = self.memory.compile()
    num_tokens_core_memory = count_tokens(core_memory)

    # Conversion of messages to OpenAI dict format, which is passed to the token counter.
    messages_openai_format = self.messages

    # Check if there's a summary message in the message queue.
    # NOTE(review): detection relies on a hardcoded substring of the
    # summarizer's injected user message; keep in sync with that template.
    if (
        len(self._messages) > 1
        and self._messages[1].role == MessageRole.user
        and isinstance(self._messages[1].text, str)
        # TODO remove hardcoding
        and "The following is a summary of the previous " in self._messages[1].text
    ):
        # Summary message exists
        assert self._messages[1].text is not None
        summary_memory = self._messages[1].text
        num_tokens_summary_memory = count_tokens(self._messages[1].text)
        # with a summary message, the real messages start at index 2
        num_tokens_messages = (
            num_tokens_from_messages(messages=messages_openai_format[2:], model=self.model) if len(messages_openai_format) > 2 else 0
        )
    else:
        summary_memory = None
        num_tokens_summary_memory = 0
        # with no summary message, the real messages start at index 1
        num_tokens_messages = (
            num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
        )

    # External (out-of-context) store sizes, plus the in-context metadata
    # block summarizing them (that block does occupy context tokens).
    num_archival_memory = self.persistence_manager.archival_memory.storage.size()
    num_recall_memory = self.persistence_manager.recall_memory.storage.size()
    external_memory_summary = compile_memory_metadata_block(
        memory_edit_timestamp=get_utc_time(),  # dummy timestamp
        archival_memory=self.persistence_manager.archival_memory,
        recall_memory=self.persistence_manager.recall_memory,
    )
    num_tokens_external_memory_summary = count_tokens(external_memory_summary)

    # BUGFIX: include the external-memory summary tokens in the total.
    # It was counted above but omitted from the sum, so the per-component
    # token fields did not add up to context_window_size_current as the
    # ContextWindowOverview schema promises.
    num_tokens_used_total = (
        num_tokens_system
        + num_tokens_core_memory
        + num_tokens_external_memory_summary
        + num_tokens_summary_memory
        + num_tokens_messages
    )

    return ContextWindowOverview(
        # context window breakdown (in messages)
        num_messages=len(self._messages),
        num_archival_memory=num_archival_memory,
        num_recall_memory=num_recall_memory,
        num_tokens_external_memory_summary=num_tokens_external_memory_summary,
        # top-level information
        context_window_size_max=self.agent_state.llm_config.context_window,
        context_window_size_current=num_tokens_used_total,
        # context window breakdown (in tokens)
        num_tokens_system=num_tokens_system,
        system_prompt=system_prompt,
        num_tokens_core_memory=num_tokens_core_memory,
        core_memory=core_memory,
        num_tokens_summary_memory=num_tokens_summary_memory,
        summary_memory=summary_memory,
        num_tokens_messages=num_tokens_messages,
        messages=self._messages,
    )
|
||||
|
||||
|
||||
def save_agent(agent: Agent, ms: MetadataStore):
|
||||
"""Save agent to metadata store"""
|
||||
|
||||
@@ -8,6 +8,43 @@ if TYPE_CHECKING:
|
||||
from letta.agent import Agent
|
||||
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
class ContextWindowOverview(BaseModel):
    """
    Overview of the context window, including the number of messages and tokens.

    Produced by Agent.get_context_window() and returned by the
    GET /{agent_id}/context route.
    """

    # top-level information
    context_window_size_max: int = Field(..., description="The maximum amount of tokens the context window can hold.")
    context_window_size_current: int = Field(..., description="The current number of tokens in the context window.")

    # context window breakdown (in messages)
    # (technically not in the context window, but useful to know)
    num_messages: int = Field(..., description="The number of messages in the context window.")
    num_archival_memory: int = Field(..., description="The number of messages in the archival memory.")
    num_recall_memory: int = Field(..., description="The number of messages in the recall memory.")
    # NOTE(review): this is a token count, not a message count — it lives in
    # this section because it describes the external (archival + recall)
    # stores whose metadata is summarized in-context.
    num_tokens_external_memory_summary: int = Field(
        ..., description="The number of tokens in the external memory summary (archival + recall metadata)."
    )

    # context window breakdown (in tokens)
    # this should all add up to context_window_size_current

    num_tokens_system: int = Field(..., description="The number of tokens in the system prompt.")
    system_prompt: str = Field(..., description="The content of the system prompt.")

    num_tokens_core_memory: int = Field(..., description="The number of tokens in the core memory.")
    core_memory: str = Field(..., description="The content of the core memory.")

    # summary fields are 0/None when no summarizer message is in the queue
    num_tokens_summary_memory: int = Field(..., description="The number of tokens in the summary memory.")
    summary_memory: Optional[str] = Field(None, description="The content of the summary memory.")

    num_tokens_messages: int = Field(..., description="The number of tokens in the messages list.")
    # TODO make list of messages?
    # messages: List[dict] = Field(..., description="The messages in the context window.")
    messages: List[Message] = Field(..., description="The messages in the context window.")
|
||||
|
||||
|
||||
class Memory(BaseModel, validate_assignment=True):
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.memory import (
|
||||
ArchivalMemorySummary,
|
||||
BasicBlockMemory,
|
||||
ContextWindowOverview,
|
||||
CreateArchivalMemory,
|
||||
Memory,
|
||||
RecallMemorySummary,
|
||||
@@ -51,6 +52,20 @@ def list_agents(
|
||||
return server.list_agents(user_id=actor.id)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="get_agent_context_window")
def get_agent_context_window(
    agent_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Retrieve the context window of a specific agent.

    Returns a ContextWindowOverview: message counts and per-component token
    counts for the agent's current context window.
    """
    # Resolve the acting user; falls back to the server's default user when
    # no user_id header is supplied.
    actor = server.get_user_or_default(user_id=user_id)

    return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
|
||||
|
||||
|
||||
@router.post("/", response_model=AgentState, operation_id="create_agent")
|
||||
def create_agent(
|
||||
agent: CreateAgent = Body(...),
|
||||
|
||||
@@ -72,7 +72,12 @@ from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
||||
from letta.schemas.memory import (
|
||||
ArchivalMemorySummary,
|
||||
ContextWindowOverview,
|
||||
Memory,
|
||||
RecallMemorySummary,
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
|
||||
from letta.schemas.organization import Organization, OrganizationCreate
|
||||
from letta.schemas.passage import Passage
|
||||
@@ -2019,3 +2024,13 @@ class SyncServer(Server):
|
||||
|
||||
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
|
||||
"""Add a new embedding model"""
|
||||
|
||||
def get_agent_context_window(
    self,
    user_id: str,
    agent_id: str,
) -> ContextWindowOverview:
    """Return the context-window breakdown for the given agent.

    Loads the agent (or fetches it from the in-memory cache) and delegates
    to Agent.get_context_window().
    """
    # NOTE(review): user_id is accepted but not consulted here — presumably
    # access control happens upstream; confirm before exposing more widely.
    return self._get_or_load_agent(agent_id=agent_id).get_context_window()
|
||||
|
||||
Reference in New Issue
Block a user