diff --git a/letta/server/server.py b/letta/server/server.py
index 8a2c38fe..af996ffc 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -1342,9 +1342,6 @@ class SyncServer(Server):
         new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
         assert new_passage_size >= curr_passage_size  # in case empty files are added
 
-        # rebuild system prompt and force
-        agent_state = await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
-
         # update job status
         job.status = JobStatus.completed
         job.metadata["num_passages"] = num_passages
diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py
index 953be77b..b641041b 100644
--- a/letta/services/agent_manager.py
+++ b/letta/services/agent_manager.py
@@ -1540,29 +1540,42 @@ class AgentManager:
         else:
             return agent_state
 
-    @trace_method
+    # TODO: This is probably one of the worst pieces of code I've ever written; please rip it up as you see fit
     @enforce_types
+    @trace_method
     async def rebuild_system_prompt_async(
-        self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True, tool_rules_solver: Optional[ToolRulesSolver] = None
-    ) -> PydanticAgentState:
+        self,
+        agent_id: str,
+        actor: PydanticUser,
+        force=False,
+        update_timestamp=True,
+        tool_rules_solver: Optional[ToolRulesSolver] = None,
+        dry_run: bool = False,
+    ) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]:
         """Rebuilds the system message with the latest memory object and any shared memory block updates
 
         Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
         Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
         """
-        # Get the current agent state
-        agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources"], actor=actor)
+        num_messages_task = self.message_manager.size_async(actor=actor, agent_id=agent_id)
+        num_archival_memories_task = self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
+        agent_state_task = self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
+
+        num_messages, num_archival_memories, agent_state = await asyncio.gather(
+            num_messages_task,
+            num_archival_memories_task,
+            agent_state_task,
+        )
+
         if not tool_rules_solver:
             tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
 
-        curr_system_message = await self.get_system_message_async(
-            agent_id=agent_id, actor=actor
-        )  # this is the system + memory bank, not just the system prompt
+        curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
 
         if curr_system_message is None:
             logger.warning(f"No system message found for agent {agent_state.id} and user {actor}")
-            return agent_state
+            return agent_state, curr_system_message, num_messages, num_archival_memories
 
         curr_system_message_openai = curr_system_message.to_openai_dict()
@@ -1576,7 +1589,7 @@ class AgentManager:
             logger.debug(
                 f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
             )
-            return agent_state
+            return agent_state, curr_system_message, num_messages, num_archival_memories
 
         # If the memory didn't update, we probably don't want to update the timestamp inside
         # For example, if we're doing a system prompt swap, this should probably be False
@@ -1586,9 +1599,6 @@ class AgentManager:
         # NOTE: a bit of a hack - we pull the timestamp from the message created_by
         memory_edit_timestamp = curr_system_message.created_at
 
-        num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
-        num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
-
         # update memory (TODO: potentially update recall/archival stats separately)
 
         new_system_message_str = compile_system_message(
@@ -1607,19 +1617,23 @@ class AgentManager:
             logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
 
             # Swap the system message out (only if there is a diff)
-            message = PydanticMessage.dict_to_message(
+            temp_message = PydanticMessage.dict_to_message(
                 agent_id=agent_id,
                 model=agent_state.llm_config.model,
                 openai_message_dict={"role": "system", "content": new_system_message_str},
             )
-            message = await self.message_manager.update_message_by_id_async(
-                message_id=curr_system_message.id,
-                message_update=MessageUpdate(**message.model_dump()),
-                actor=actor,
-            )
-            return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
-        else:
-            return agent_state
+            temp_message.id = curr_system_message.id
+
+            if not dry_run:
+                await self.message_manager.update_message_by_id_async(
+                    message_id=curr_system_message.id,
+                    message_update=MessageUpdate(**temp_message.model_dump()),
+                    actor=actor,
+                )
+            else:
+                curr_system_message = temp_message
+
+        return agent_state, curr_system_message, num_messages, num_archival_memories
 
     @trace_method
     @enforce_types
@@ -1781,7 +1795,7 @@ class AgentManager:
 
         # NOTE: don't do this since re-building the memory is handled at the start of the step
         # rebuild memory - this records the last edited timestamp of the memory
        # TODO: pass in update timestamp from block edit time
-        agent_state = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor)
+        await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor)
 
         return agent_state
@@ -1845,12 +1859,8 @@ class AgentManager:
             )
 
             # Commit the changes
-            await agent.update_async(session, actor=actor)
-
-            # Force rebuild of system prompt so that the agent is updated with passage count
-            pydantic_agent = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
-
-            return pydantic_agent
+            agent = await agent.update_async(session, actor=actor)
+            return await agent.to_pydantic_async()
 
     @trace_method
     @enforce_types
@@ -2761,8 +2771,11 @@ class AgentManager:
         results = [row[0] for row in result.all()]
         return results
 
+    @trace_method
     async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
-        agent_state = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
+        agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
+            agent_id=agent_id, actor=actor, force=True, dry_run=True
+        )
         calculator = ContextWindowCalculator()
 
         if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION" or agent_state.llm_config.model_endpoint_type == "anthropic":
@@ -2778,5 +2791,7 @@ class AgentManager:
             actor=actor,
             token_counter=token_counter,
             message_manager=self.message_manager,
-            passage_manager=self.passage_manager,
+            system_message_compiled=system_message,
+            num_archival_memories=num_archival_memories,
+            num_messages=num_messages,
         )
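Reviewer note: the new `dry_run` flag is what lets `get_context_window` reuse the rebuild path without persisting anything. A minimal sketch of the new four-value contract, assuming an `AgentManager` instance `mgr` and a `PydanticUser` `actor` already exist (the agent id is hypothetical):

```python
# Sketch only: `mgr` (an AgentManager) and `actor` (a PydanticUser) are assumed
# to be constructed elsewhere; "agent-123" is a hypothetical id.
async def inspect_rebuild(mgr, actor):
    agent_state, system_message, num_messages, num_archival_memories = await mgr.rebuild_system_prompt_async(
        agent_id="agent-123",  # hypothetical id
        actor=actor,
        force=True,    # rebuild even when the memory hash is unchanged
        dry_run=True,  # compile the new system message but skip the DB write
    )
    # With dry_run=True, `system_message` is the freshly compiled, unpersisted
    # message (it keeps the old message id via temp_message.id); with
    # dry_run=False the stored system message is updated in place instead.
    return agent_state, system_message, num_messages, num_archival_memories
```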
diff --git a/letta/services/context_window_calculator/context_window_calculator.py b/letta/services/context_window_calculator/context_window_calculator.py
index b3a89028..c405d289 100644
--- a/letta/services/context_window_calculator/context_window_calculator.py
+++ b/letta/services/context_window_calculator/context_window_calculator.py
@@ -8,8 +8,10 @@
 from letta.schemas.agent import AgentState
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message_content import TextContent
 from letta.schemas.memory import ContextWindowOverview
+from letta.schemas.message import Message
 from letta.schemas.user import User as PydanticUser
 from letta.services.context_window_calculator.token_counter import TokenCounter
+from letta.services.message_manager import MessageManager
 
 logger = get_logger(__name__)
@@ -57,16 +59,18 @@
         return None, 1
 
     async def calculate_context_window(
-        self, agent_state: AgentState, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
+        self,
+        agent_state: AgentState,
+        actor: PydanticUser,
+        token_counter: TokenCounter,
+        message_manager: MessageManager,
+        system_message_compiled: Message,
+        num_archival_memories: int,
+        num_messages: int,
     ) -> ContextWindowOverview:
         """Calculate context window information using the provided token counter"""
-
-        # Fetch data concurrently
-        (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
-            message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor),
-            passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id),
-            message_manager.size_async(actor=actor, agent_id=agent_state.id),
-        )
+        messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids[1:], actor=actor)
+        in_context_messages = [system_message_compiled] + messages
 
         # Convert messages to appropriate format
         converted_messages = token_counter.convert_messages(in_context_messages)
@@ -129,8 +133,8 @@
         return ContextWindowOverview(
             # context window breakdown (in messages)
             num_messages=len(in_context_messages),
-            num_archival_memory=passage_manager_size,
-            num_recall_memory=message_manager_size,
+            num_archival_memory=num_archival_memories,
+            num_recall_memory=num_messages,
             num_tokens_external_memory_summary=num_tokens_external_memory_summary,
             external_memory_summary=external_memory_summary,
             # top-level information
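Reviewer note: the calculator no longer fetches passage and message counts itself; the caller supplies them along with the compiled system message. A sketch of the new wiring, mirroring `get_context_window` in the agent_manager.py diff above (`calculator`, `token_counter`, and `message_manager` are assumed to be constructed elsewhere, and the four inputs come from the dry-run rebuild):

```python
# Sketch only: all arguments are assumed to come from surrounding setup and
# from the dry-run rebuild_system_prompt_async call shown earlier.
async def build_overview(calculator, token_counter, message_manager, actor,
                         agent_state, system_message, num_messages, num_archival_memories):
    overview = await calculator.calculate_context_window(
        agent_state=agent_state,
        actor=actor,
        token_counter=token_counter,
        message_manager=message_manager,
        system_message_compiled=system_message,
        num_archival_memories=num_archival_memories,
        num_messages=num_messages,
    )
    # Internally, message_ids[0] is skipped and system_message_compiled is
    # prepended, so token counts reflect the recompiled prompt, not the stored one.
    return overview
```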
f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}", prefix="token_counter", @@ -56,6 +59,7 @@ class AnthropicTokenCounter(TokenCounter): return 0 return await self.client.count_tokens(model=self.model, messages=messages) + @trace_method @async_redis_cache( key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}", prefix="token_counter", @@ -76,6 +80,7 @@ class TiktokenCounter(TokenCounter): def __init__(self, model: str): self.model = model + @trace_method @async_redis_cache( key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}", prefix="token_counter", @@ -86,6 +91,7 @@ class TiktokenCounter(TokenCounter): return 0 return count_tokens(text) + @trace_method @async_redis_cache( key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}", prefix="token_counter", @@ -98,6 +104,7 @@ class TiktokenCounter(TokenCounter): return num_tokens_from_messages(messages=messages, model=self.model) + @trace_method @async_redis_cache( key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}", prefix="token_counter", diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 82abb3a7..950e44d0 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -237,7 +237,7 @@ def validate_context_window_overview( # 16. Check attached file is visible if attached_file: - assert attached_file.visible_content in overview.core_memory + assert attached_file.visible_content in overview.core_memory, "File must be attached in core memory" assert '" in overview.core_memory