feat: add anthropic token counter for cloud (#2289)
This commit is contained in:
165
letta/agent.py
165
letta/agent.py
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
@@ -43,7 +45,7 @@ from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_prompt_template_for_agent_type
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.enums import MessageRole, ProviderType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
@@ -1270,10 +1272,19 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
async def get_context_window_async(self) -> ContextWindowOverview:
    """Return the agent's context window overview.

    The token-counting path is chosen from the ``LETTA_ENVIRONMENT``
    environment variable: when it is exactly ``"PRODUCTION"`` the call is
    delegated to ``get_context_window_from_anthropic_async``; for any other
    value (or when unset) it is delegated to
    ``get_context_window_from_tiktoken_async``.
    """
    in_production = os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION"
    if not in_production:
        return await self.get_context_window_from_tiktoken_async()
    return await self.get_context_window_from_anthropic_async()
|
||||
|
||||
async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview:
|
||||
"""Get the context window of the agent"""
|
||||
# Grab the in-context messages
|
||||
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
||||
in_context_messages = await self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user)
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
|
||||
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
)
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
|
||||
# Extract system, memory and external summary
|
||||
@@ -1361,7 +1372,155 @@ class Agent(BaseAgent):
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(in_context_messages),
|
||||
num_archival_memory=agent_manager_passage_size,
|
||||
num_archival_memory=passage_manager_size,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
external_memory_summary=external_memory_summary,
|
||||
# top-level information
|
||||
context_window_size_max=self.agent_state.llm_config.context_window,
|
||||
context_window_size_current=num_tokens_used_total,
|
||||
# context window breakdown (in tokens)
|
||||
num_tokens_system=num_tokens_system,
|
||||
system_prompt=system_prompt,
|
||||
num_tokens_core_memory=num_tokens_core_memory,
|
||||
core_memory=core_memory,
|
||||
num_tokens_summary_memory=num_tokens_summary_memory,
|
||||
summary_memory=summary_memory,
|
||||
num_tokens_messages=num_tokens_messages,
|
||||
messages=in_context_messages,
|
||||
# related to functions
|
||||
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
|
||||
functions_definitions=available_functions_definitions,
|
||||
)
|
||||
|
||||
async def get_context_window_from_anthropic_async(self) -> ContextWindowOverview:
|
||||
"""Get the context window of the agent"""
|
||||
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=self.user)
|
||||
model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None
|
||||
|
||||
# Grab the in-context messages
|
||||
# conversion of messages to anthropic dict format, which is passed to the token counter
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
|
||||
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
)
|
||||
in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]
|
||||
|
||||
# Extract system, memory and external summary
|
||||
if (
|
||||
len(in_context_messages) > 0
|
||||
and in_context_messages[0].role == MessageRole.system
|
||||
and in_context_messages[0].content
|
||||
and len(in_context_messages[0].content) == 1
|
||||
and isinstance(in_context_messages[0].content[0], TextContent)
|
||||
):
|
||||
system_message = in_context_messages[0].content[0].text
|
||||
|
||||
external_memory_marker_pos = system_message.find("###")
|
||||
core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
|
||||
if external_memory_marker_pos != -1 and core_memory_marker_pos != -1:
|
||||
system_prompt = system_message[:external_memory_marker_pos].strip()
|
||||
external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip()
|
||||
core_memory = system_message[core_memory_marker_pos:].strip()
|
||||
else:
|
||||
# if no markers found, put everything in system message
|
||||
system_prompt = system_message
|
||||
external_memory_summary = None
|
||||
core_memory = None
|
||||
else:
|
||||
# if no system message, fall back on agent's system prompt
|
||||
system_prompt = self.agent_state.system
|
||||
external_memory_summary = None
|
||||
core_memory = None
|
||||
|
||||
num_tokens_system_coroutine = anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}])
|
||||
num_tokens_core_memory_coroutine = (
|
||||
anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": core_memory}])
|
||||
if core_memory
|
||||
else asyncio.sleep(0, result=0)
|
||||
)
|
||||
num_tokens_external_memory_summary_coroutine = (
|
||||
anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": external_memory_summary}])
|
||||
if external_memory_summary
|
||||
else asyncio.sleep(0, result=0)
|
||||
)
|
||||
|
||||
# Check if there's a summary message in the message queue
|
||||
if (
|
||||
len(in_context_messages) > 1
|
||||
and in_context_messages[1].role == MessageRole.user
|
||||
and in_context_messages[1].content
|
||||
and len(in_context_messages[1].content) == 1
|
||||
and isinstance(in_context_messages[1].content[0], TextContent)
|
||||
# TODO remove hardcoding
|
||||
and "The following is a summary of the previous " in in_context_messages[1].content[0].text
|
||||
):
|
||||
# Summary message exists
|
||||
text_content = in_context_messages[1].content[0].text
|
||||
assert text_content is not None
|
||||
summary_memory = text_content
|
||||
num_tokens_summary_memory_coroutine = anthropic_client.count_tokens(
|
||||
model=model, messages=[{"role": "user", "content": summary_memory}]
|
||||
)
|
||||
# with a summary message, the real messages start at index 2
|
||||
num_tokens_messages_coroutine = (
|
||||
anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[2:])
|
||||
if len(in_context_messages_anthropic) > 2
|
||||
else asyncio.sleep(0, result=0)
|
||||
)
|
||||
|
||||
else:
|
||||
summary_memory = None
|
||||
num_tokens_summary_memory_coroutine = asyncio.sleep(0, result=0)
|
||||
# with no summary message, the real messages start at index 1
|
||||
num_tokens_messages_coroutine = (
|
||||
anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[1:])
|
||||
if len(in_context_messages_anthropic) > 1
|
||||
else asyncio.sleep(0, result=0)
|
||||
)
|
||||
|
||||
# tokens taken up by function definitions
|
||||
if self.agent_state.tools and len(self.agent_state.tools) > 0:
|
||||
available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in self.agent_state.tools]
|
||||
num_tokens_available_functions_definitions_coroutine = anthropic_client.count_tokens(
|
||||
model=model,
|
||||
tools=available_functions_definitions,
|
||||
)
|
||||
else:
|
||||
available_functions_definitions = []
|
||||
num_tokens_available_functions_definitions_coroutine = asyncio.sleep(0, result=0)
|
||||
|
||||
(
|
||||
num_tokens_system,
|
||||
num_tokens_core_memory,
|
||||
num_tokens_external_memory_summary,
|
||||
num_tokens_summary_memory,
|
||||
num_tokens_messages,
|
||||
num_tokens_available_functions_definitions,
|
||||
) = await asyncio.gather(
|
||||
num_tokens_system_coroutine,
|
||||
num_tokens_core_memory_coroutine,
|
||||
num_tokens_external_memory_summary_coroutine,
|
||||
num_tokens_summary_memory_coroutine,
|
||||
num_tokens_messages_coroutine,
|
||||
num_tokens_available_functions_definitions_coroutine,
|
||||
)
|
||||
|
||||
num_tokens_used_total = (
|
||||
num_tokens_system # system prompt
|
||||
+ num_tokens_available_functions_definitions # function definitions
|
||||
+ num_tokens_core_memory # core memory
|
||||
+ num_tokens_external_memory_summary # metadata (statistics) about recall/archival
|
||||
+ num_tokens_summary_memory # summary of ongoing conversation
|
||||
+ num_tokens_messages # tokens taken by messages
|
||||
)
|
||||
assert isinstance(num_tokens_used_total, int)
|
||||
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(in_context_messages),
|
||||
num_archival_memory=passage_manager_size,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
external_memory_summary=external_memory_summary,
|
||||
|
||||
@@ -248,6 +248,24 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
return data
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[Tool] = None) -> int:
    """Count input tokens via Anthropic's token-counting endpoint.

    Args:
        messages: Anthropic-format message dicts. ``None`` or an empty list
            both mean "no messages": a placeholder user message is sent
            instead and a fixed correction is subtracted from the result.
        model: Model name to count against. Falls back to
            ``"claude-3-7-sonnet-20250219"`` when not given.
        tools: Tool definitions; converted with
            ``convert_tools_to_anthropic_format`` before being sent.

    Returns:
        ``input_tokens`` as reported by the API, minus the placeholder
        correction when no messages were supplied.
    """
    # NOTE(review): a fresh client is built on every call with no arguments,
    # so it does not use anything configured on ``self`` — confirm that is
    # intended for this deployment.
    client = anthropic.AsyncAnthropic()

    # Normalise "no messages" to None. The previous guard
    # (`messages and len(messages) == 0`) could never be true because an
    # empty list is falsy, so `[]` sent the placeholder below but skipped
    # the correction and returned an inflated count.
    if not messages:
        messages = None

    anthropic_tools = convert_tools_to_anthropic_format(tools) if tools else None

    result = await client.beta.messages.count_tokens(
        model=model or "claude-3-7-sonnet-20250219",
        # A placeholder message is sent when the caller supplied none.
        messages=messages or [{"role": "user", "content": "hi"}],
        tools=anthropic_tools or [],
    )

    token_count = result.input_tokens
    if messages is None:
        # Presumably 8 is the token cost of the placeholder "hi" message —
        # TODO confirm against the API for the models in use.
        token_count -= 8
    return token_count
|
||||
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
if isinstance(e, anthropic.APIConnectionError):
|
||||
logger.warning(f"[Anthropic] API connection error: {e.__cause__}")
|
||||
|
||||
@@ -202,7 +202,7 @@ async def import_agent_serialized(
|
||||
|
||||
|
||||
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")
|
||||
def retrieve_agent_context_window(
|
||||
async def retrieve_agent_context_window(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@@ -210,9 +210,12 @@ def retrieve_agent_context_window(
|
||||
"""
|
||||
Retrieve the context window of a specific agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_context_window(agent_id=agent_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
try:
|
||||
return await server.get_agent_context_window_async(agent_id=agent_id, actor=actor)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
|
||||
class CreateAgentRequest(CreateAgent):
|
||||
|
||||
@@ -1546,6 +1546,10 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return letta_agent.get_context_window()
|
||||
|
||||
async def get_agent_context_window_async(self, agent_id: str, actor: User) -> ContextWindowOverview:
    """Load the agent identified by ``agent_id`` on behalf of ``actor`` and
    return the result of awaiting its ``get_context_window_async``.

    NOTE(review): ``load_agent`` is called without ``await`` here; if it does
    blocking I/O it will block the event loop — confirm.
    """
    return await self.load_agent(agent_id=agent_id, actor=actor).get_context_window_async()
|
||||
|
||||
def run_tool_from_source(
|
||||
self,
|
||||
actor: User,
|
||||
|
||||
Reference in New Issue
Block a user