diff --git a/letta/agent.py b/letta/agent.py index 803cb4c0..6571f736 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -222,6 +222,7 @@ class Agent(BaseAgent): # refresh memory from DB (using block ids) self.agent_state.memory = Memory( blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()], + file_blocks=self.agent_state.memory.file_blocks, prompt_template=get_prompt_template_for_agent_type(self.agent_state.agent_type), ) @@ -866,6 +867,7 @@ class Agent(BaseAgent): # only pulling latest block data if shared memory is being used current_persisted_memory = Memory( blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()], + file_blocks=self.agent_state.memory.file_blocks, prompt_template=get_prompt_template_for_agent_type(self.agent_state.agent_type), ) # read blocks from DB self.update_memory_if_changed(current_persisted_memory) diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 025e01cb..2080f913 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -22,6 +22,7 @@ from letta.schemas.tool_rule import ToolRule if TYPE_CHECKING: from letta.orm.agents_tags import AgentsTags + from letta.orm.files_agents import FileAgent from letta.orm.identity import Identity from letta.orm.organization import Organization from letta.orm.source import Source @@ -126,6 +127,12 @@ class Agent(SqlalchemyBase, OrganizationMixin, AsyncAttrs): back_populates="manager_agent", ) batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="agent", lazy="selectin") + file_agents: Mapped[List["FileAgent"]] = relationship( + "FileAgent", + back_populates="agent", + cascade="all, delete-orphan", + lazy="selectin", + ) def to_pydantic(self, include_relationships: Optional[Set[str]] = None) -> PydanticAgentState: """ @@ -186,6 +193,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, AsyncAttrs): "sources": lambda: 
[s.to_pydantic() for s in self.sources], "memory": lambda: Memory( blocks=[b.to_pydantic() for b in self.core_memory], + file_blocks=[block for b in self.file_agents if (block := b.to_pydantic_block()) is not None], prompt_template=get_prompt_template_for_agent_type(self.agent_type), ), "identity_ids": lambda: [i.id for i in self.identities], @@ -276,23 +284,10 @@ class Agent(SqlalchemyBase, OrganizationMixin, AsyncAttrs): if "tool_exec_environment_variables" in include_relationships else empty_list_async() ) + file_agents = self.awaitable_attrs.file_agents if "memory" in include_relationships else empty_list_async() - ( - tags, - tools, - sources, - memory, - identities, - multi_agent_group, - tool_exec_environment_variables, - ) = await asyncio.gather( - tags, - tools, - sources, - memory, - identities, - multi_agent_group, - tool_exec_environment_variables, + (tags, tools, sources, memory, identities, multi_agent_group, tool_exec_environment_variables, file_agents) = await asyncio.gather( + tags, tools, sources, memory, identities, multi_agent_group, tool_exec_environment_variables, file_agents ) state["tags"] = [t.tag for t in tags] @@ -300,6 +295,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, AsyncAttrs): state["sources"] = [s.to_pydantic() for s in sources] state["memory"] = Memory( blocks=[m.to_pydantic() for m in memory], + file_blocks=[block for b in self.file_agents if (block := b.to_pydantic_block()) is not None], prompt_template=get_prompt_template_for_agent_type(self.agent_type), ) state["identity_ids"] = [i.id for i in identities] diff --git a/letta/orm/file.py b/letta/orm/file.py index b27ec7e1..baf14d26 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -8,6 +8,7 @@ from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.file import FileMetadata as PydanticFileMetadata if TYPE_CHECKING: + from letta.orm.files_agents import FileAgent from letta.orm.organization import Organization from letta.orm.passage import 
SourcePassage from letta.orm.source import Source @@ -32,3 +33,4 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin): source_passages: Mapped[List["SourcePassage"]] = relationship( "SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan" ) + file_agents: Mapped[List["FileAgent"]] = relationship("FileAgent", back_populates="file", lazy="selectin") diff --git a/letta/orm/files_agents.py b/letta/orm/files_agents.py index e005ba74..ac9d9e34 100644 --- a/letta/orm/files_agents.py +++ b/letta/orm/files_agents.py @@ -3,14 +3,15 @@ from datetime import datetime from typing import TYPE_CHECKING, Optional from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, UniqueConstraint, func -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.block import Block as PydanticBlock from letta.schemas.file import FileAgent as PydanticFileAgent if TYPE_CHECKING: - pass + from letta.orm.file import FileMetadata class FileAgent(SqlalchemyBase, OrganizationMixin): @@ -43,3 +44,27 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): nullable=False, doc="UTC timestamp when this agent last accessed the file.", ) + + # relationships + agent: Mapped["Agent"] = relationship( + "Agent", + back_populates="file_agents", + lazy="selectin", + ) + file: Mapped["FileMetadata"] = relationship( + "FileMetadata", + foreign_keys=[file_id], + lazy="selectin", + ) + + # TODO: This is temporary as we figure out if we want FileBlock as a first class citizen + def to_pydantic_block(self) -> Optional[PydanticBlock]: + if self.is_open: + return PydanticBlock( + organization_id=self.organization_id, + value=self.visible_content if self.visible_content else "", + label=self.file.file_name, + read_only=True, + ) + else: + return None diff --git 
a/letta/prompts/system/memgpt_base.txt b/letta/prompts/system/memgpt_base.txt index e032d23a..6b445dc5 100644 --- a/letta/prompts/system/memgpt_base.txt +++ b/letta/prompts/system/memgpt_base.txt @@ -41,9 +41,14 @@ You can edit your core memory using the 'core_memory_append' and 'core_memory_re Archival memory (infinite size): Your archival memory is infinite size, but is held outside of your immediate context, so you must explicitly run a retrieval/search operation to see data inside it. -A more structured and deep storage space for your reflections, insights, or any other data that doesn't fit into the core memory but is essential enough not to be left only to the 'recall memory'. +A more structured and deep storage space for your reflections, insights, or any memories that arise from interacting with the user that don't fit into the core memory but are essential enough not to be left only to the 'recall memory'. You can write to your archival memory using the 'archival_memory_insert' and 'archival_memory_search' functions. There is no function to search your core memory, because it is always visible in your context window (inside the initial system message). +Data sources: +You may be given access to external sources of data relevant to the user's interaction. For example: code, style guides, and documentation relevant +to the current interaction with the user. Your core memory will contain information about the contents of these data sources. You will have access +to functions to open and close files, as in a filesystem, and maintain only the files that are relevant to the user's interaction. + Base instructions finished. From now on, you are going to act as your persona. 
diff --git a/letta/prompts/system/memgpt_v2_chat.txt b/letta/prompts/system/memgpt_v2_chat.txt index e6be44a8..4af687a7 100644 --- a/letta/prompts/system/memgpt_v2_chat.txt +++ b/letta/prompts/system/memgpt_v2_chat.txt @@ -47,5 +47,11 @@ Archival memory (infinite size): Your archival memory is infinite size, but is held outside your immediate context, so you must explicitly run a retrieval/search operation to see data inside it. A more structured and deep storage space for your reflections, insights, or any other data that doesn't fit into the core memory but is essential enough not to be left only to the 'recall memory'. +Data sources: +You may be given access to external sources of data relevant to the user's interaction. For example: code, style guides, and documentation relevant +to the current interaction with the user. Your core memory will contain information about the contents of these data sources. You will have access +to functions to open and close files, as in a filesystem, and maintain only the files that are relevant to the user's interaction. + + Base instructions finished. 
diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index bf328d99..b94fd93a 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -323,6 +323,26 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% if not loop.last %}\n{% endif %}" "{% endfor %}" "\n" + "\nThe following memory blocks are currently accessible in your core memory unit:\n\n" + "{% for block in file_blocks %}" + "<{{ block.label }}>\n" + "\n" + "{{ block.description }}\n" + "\n" + "" + "{% if block.read_only %}\n- read_only=true{% endif %}\n- chars_current={{ block.value|length }}\n- chars_limit={{ block.limit }}\n" + "\n" + "\n" + f"{CORE_MEMORY_LINE_NUMBER_WARNING}\n" + "{% for line in block.value.split('\\n') %}" + "Line {{ loop.index }}: {{ line }}\n" + "{% endfor %}" + "\n" + "\n" + "{% if not loop.last %}\n{% endif %}" + "{% endfor %}" + "\n" + "" ) # Default setup (MemGPT), no line numbers else: @@ -343,4 +363,20 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% if not loop.last %}\n{% endif %}" "{% endfor %}" "\n" + "\nThe following memory files are currently accessible:\n\n" + "{% for block in file_blocks%}" + "<{{ block.label }}>\n" + "\n" + "{{ block.description }}\n" + "\n" + "" + "{% if block.read_only %}\n- read_only=true{% endif %}\n- chars_current={{ block.value|length }}\n- chars_limit={{ block.limit }}\n" + "\n" + "\n" + "{{ block.value }}\n" + "\n" + "\n" + "{% if not loop.last %}\n{% endif %}" + "{% endfor %}" + "\n" ) diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index e64533be..ac5ce6b0 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -65,6 +65,9 @@ class Memory(BaseModel, validate_assignment=True): # Memory.block contains the list of memory blocks in the core memory blocks: List[Block] = Field(..., description="Memory blocks contained in the agent's in-context memory") + file_blocks: List[Block] = Field( + default_factory=list, description="Blocks 
representing the agent's in-context memory of an attached file" + ) # Memory.template is a Jinja2 template for compiling memory module into a prompt string. prompt_template: str = Field( @@ -96,7 +99,7 @@ class Memory(BaseModel, validate_assignment=True): Template(prompt_template) # Validate compatibility with current memory structure - Template(prompt_template).render(blocks=self.blocks) + Template(prompt_template).render(blocks=self.blocks, file_blocks=self.file_blocks) # If we get here, the template is valid and compatible self.prompt_template = prompt_template @@ -108,7 +111,7 @@ class Memory(BaseModel, validate_assignment=True): def compile(self) -> str: """Generate a string representation of the memory in-context using the Jinja2 template""" template = Template(self.prompt_template) - return template.render(blocks=self.blocks) + return template.render(blocks=self.blocks, file_blocks=self.file_blocks) def list_block_labels(self) -> List[str]: """Return a list of the block names held inside the memory object""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index b656049a..2808b03a 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -5,7 +5,6 @@ from typing import Dict, List, Optional, Set, Tuple import numpy as np import sqlalchemy as sa -from openai.types.beta.function_tool import FunctionTool as OpenAITool from sqlalchemy import Select, and_, delete, func, insert, literal, or_, select, union_all from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -47,10 +46,9 @@ from letta.schemas.block import DEFAULT_BLOCKS from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageRole, ProviderType +from letta.schemas.enums import ProviderType from letta.schemas.group import Group as PydanticGroup from letta.schemas.group import ManagerType -from 
letta.schemas.letta_message_content import TextContent from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message from letta.schemas.message import Message as PydanticMessage @@ -66,6 +64,8 @@ from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.db import db_registry from letta.services.block_manager import BlockManager +from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator +from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter from letta.services.helpers.agent_manager_helper import ( _apply_filters, _apply_identity_filters, @@ -87,7 +87,7 @@ from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings from letta.tracing import trace_method -from letta.utils import count_tokens, enforce_types, united_diff +from letta.utils import enforce_types, united_diff logger = get_logger(__name__) @@ -1636,6 +1636,7 @@ class AgentManager: ) agent_state.memory = Memory( blocks=blocks, + file_blocks=agent_state.memory.file_blocks, prompt_template=get_prompt_template_for_agent_type(agent_state.agent_type), ) @@ -2621,277 +2622,21 @@ class AgentManager: return results async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: - if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION": - return await self.get_context_window_from_anthropic_async(agent_id=agent_id, actor=actor) - return await self.get_context_window_from_tiktoken_async(agent_id=agent_id, actor=actor) - - async def get_context_window_from_anthropic_async(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: - """Get the context window of the agent""" agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) - 
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor) - model = agent_state.llm_config.model if agent_state.llm_config.model_endpoint_type == "anthropic" else None + calculator = ContextWindowCalculator() - # Grab the in-context messages - # conversion of messages to anthropic dict format, which is passed to the token counter - (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor), - self.passage_manager.size_async(actor=actor, agent_id=agent_id), - self.message_manager.size_async(actor=actor, agent_id=agent_id), - ) - in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages] + if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION" or agent_state.llm_config.model_endpoint_type == "anthropic": + anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor) + model = agent_state.llm_config.model if agent_state.llm_config.model_endpoint_type == "anthropic" else None - # Extract system, memory and external summary - if ( - len(in_context_messages) > 0 - and in_context_messages[0].role == MessageRole.system - and in_context_messages[0].content - and len(in_context_messages[0].content) == 1 - and isinstance(in_context_messages[0].content[0], TextContent) - ): - system_message = in_context_messages[0].content[0].text - - external_memory_marker_pos = system_message.find("###") - core_memory_marker_pos = system_message.find("<", external_memory_marker_pos) - if external_memory_marker_pos != -1 and core_memory_marker_pos != -1: - system_prompt = system_message[:external_memory_marker_pos].strip() - external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip() - core_memory = system_message[core_memory_marker_pos:].strip() - else: - # if no markers found, put everything in system message - system_prompt = system_message - 
external_memory_summary = "" - core_memory = "" + token_counter = AnthropicTokenCounter(anthropic_client, model) # noqa else: - # if no system message, fall back on agent's system prompt - system_prompt = agent_state.system - external_memory_summary = "" - core_memory = "" + token_counter = TiktokenCounter(agent_state.llm_config.model) - num_tokens_system_coroutine = anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}]) - num_tokens_core_memory_coroutine = ( - anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": core_memory}]) - if core_memory - else asyncio.sleep(0, result=0) - ) - num_tokens_external_memory_summary_coroutine = ( - anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": external_memory_summary}]) - if external_memory_summary - else asyncio.sleep(0, result=0) - ) - - # Check if there's a summary message in the message queue - if ( - len(in_context_messages) > 1 - and in_context_messages[1].role == MessageRole.user - and in_context_messages[1].content - and len(in_context_messages[1].content) == 1 - and isinstance(in_context_messages[1].content[0], TextContent) - # TODO remove hardcoding - and "The following is a summary of the previous " in in_context_messages[1].content[0].text - ): - # Summary message exists - text_content = in_context_messages[1].content[0].text - assert text_content is not None - summary_memory = text_content - num_tokens_summary_memory_coroutine = anthropic_client.count_tokens( - model=model, messages=[{"role": "user", "content": summary_memory}] - ) - # with a summary message, the real messages start at index 2 - num_tokens_messages_coroutine = ( - anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[2:]) - if len(in_context_messages_anthropic) > 2 - else asyncio.sleep(0, result=0) - ) - - else: - summary_memory = None - num_tokens_summary_memory_coroutine = asyncio.sleep(0, result=0) - # with no 
summary message, the real messages start at index 1 - num_tokens_messages_coroutine = ( - anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[1:]) - if len(in_context_messages_anthropic) > 1 - else asyncio.sleep(0, result=0) - ) - - # tokens taken up by function definitions - if agent_state.tools and len(agent_state.tools) > 0: - available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in agent_state.tools] - num_tokens_available_functions_definitions_coroutine = anthropic_client.count_tokens( - model=model, - tools=available_functions_definitions, - ) - else: - available_functions_definitions = [] - num_tokens_available_functions_definitions_coroutine = asyncio.sleep(0, result=0) - - ( - num_tokens_system, - num_tokens_core_memory, - num_tokens_external_memory_summary, - num_tokens_summary_memory, - num_tokens_messages, - num_tokens_available_functions_definitions, - ) = await asyncio.gather( - num_tokens_system_coroutine, - num_tokens_core_memory_coroutine, - num_tokens_external_memory_summary_coroutine, - num_tokens_summary_memory_coroutine, - num_tokens_messages_coroutine, - num_tokens_available_functions_definitions_coroutine, - ) - - num_tokens_used_total = ( - num_tokens_system # system prompt - + num_tokens_available_functions_definitions # function definitions - + num_tokens_core_memory # core memory - + num_tokens_external_memory_summary # metadata (statistics) about recall/archival - + num_tokens_summary_memory # summary of ongoing conversation - + num_tokens_messages # tokens taken by messages - ) - assert isinstance(num_tokens_used_total, int) - - return ContextWindowOverview( - # context window breakdown (in messages) - num_messages=len(in_context_messages), - num_archival_memory=passage_manager_size, - num_recall_memory=message_manager_size, - num_tokens_external_memory_summary=num_tokens_external_memory_summary, - external_memory_summary=external_memory_summary, - # top-level 
information - context_window_size_max=agent_state.llm_config.context_window, - context_window_size_current=num_tokens_used_total, - # context window breakdown (in tokens) - num_tokens_system=num_tokens_system, - system_prompt=system_prompt, - num_tokens_core_memory=num_tokens_core_memory, - core_memory=core_memory, - num_tokens_summary_memory=num_tokens_summary_memory, - summary_memory=summary_memory, - num_tokens_messages=num_tokens_messages, - messages=in_context_messages, - # related to functions - num_tokens_functions_definitions=num_tokens_available_functions_definitions, - functions_definitions=available_functions_definitions, - ) - - async def get_context_window_from_tiktoken_async(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: - """Get the context window of the agent""" - from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages - - agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) - # Grab the in-context messages - # conversion of messages to OpenAI dict format, which is passed to the token counter - (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor), - self.passage_manager.size_async(actor=actor, agent_id=agent_id), - self.message_manager.size_async(actor=actor, agent_id=agent_id), - ) - in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] - - # Extract system, memory and external summary - if ( - len(in_context_messages) > 0 - and in_context_messages[0].role == MessageRole.system - and in_context_messages[0].content - and len(in_context_messages[0].content) == 1 - and isinstance(in_context_messages[0].content[0], TextContent) - ): - system_message = in_context_messages[0].content[0].text - - external_memory_marker_pos = system_message.find("###") - core_memory_marker_pos = system_message.find("<", 
external_memory_marker_pos) - if external_memory_marker_pos != -1 and core_memory_marker_pos != -1: - system_prompt = system_message[:external_memory_marker_pos].strip() - external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip() - core_memory = system_message[core_memory_marker_pos:].strip() - else: - # if no markers found, put everything in system message - system_prompt = system_message - external_memory_summary = "" - core_memory = "" - else: - # if no system message, fall back on agent's system prompt - system_prompt = agent_state.system - external_memory_summary = "" - core_memory = "" - - num_tokens_system = count_tokens(system_prompt) - num_tokens_core_memory = count_tokens(core_memory) - num_tokens_external_memory_summary = count_tokens(external_memory_summary) - - # Check if there's a summary message in the message queue - if ( - len(in_context_messages) > 1 - and in_context_messages[1].role == MessageRole.user - and in_context_messages[1].content - and len(in_context_messages[1].content) == 1 - and isinstance(in_context_messages[1].content[0], TextContent) - # TODO remove hardcoding - and "The following is a summary of the previous " in in_context_messages[1].content[0].text - ): - # Summary message exists - text_content = in_context_messages[1].content[0].text - assert text_content is not None - summary_memory = text_content - num_tokens_summary_memory = count_tokens(text_content) - # with a summary message, the real messages start at index 2 - num_tokens_messages = ( - num_tokens_from_messages(messages=in_context_messages_openai[2:], model=agent_state.llm_config.model) - if len(in_context_messages_openai) > 2 - else 0 - ) - - else: - summary_memory = None - num_tokens_summary_memory = 0 - # with no summary message, the real messages start at index 1 - num_tokens_messages = ( - num_tokens_from_messages(messages=in_context_messages_openai[1:], model=agent_state.llm_config.model) - if 
len(in_context_messages_openai) > 1 - else 0 - ) - - # tokens taken up by function definitions - agent_state_tool_jsons = [t.json_schema for t in agent_state.tools] - if agent_state_tool_jsons: - available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons] - num_tokens_available_functions_definitions = num_tokens_from_functions( - functions=agent_state_tool_jsons, model=agent_state.llm_config.model - ) - else: - available_functions_definitions = [] - num_tokens_available_functions_definitions = 0 - - num_tokens_used_total = ( - num_tokens_system # system prompt - + num_tokens_available_functions_definitions # function definitions - + num_tokens_core_memory # core memory - + num_tokens_external_memory_summary # metadata (statistics) about recall/archival - + num_tokens_summary_memory # summary of ongoing conversation - + num_tokens_messages # tokens taken by messages - ) - assert isinstance(num_tokens_used_total, int) - - return ContextWindowOverview( - # context window breakdown (in messages) - num_messages=len(in_context_messages), - num_archival_memory=passage_manager_size, - num_recall_memory=message_manager_size, - num_tokens_external_memory_summary=num_tokens_external_memory_summary, - external_memory_summary=external_memory_summary, - # top-level information - context_window_size_max=agent_state.llm_config.context_window, - context_window_size_current=num_tokens_used_total, - # context window breakdown (in tokens) - num_tokens_system=num_tokens_system, - system_prompt=system_prompt, - num_tokens_core_memory=num_tokens_core_memory, - core_memory=core_memory, - num_tokens_summary_memory=num_tokens_summary_memory, - summary_memory=summary_memory, - num_tokens_messages=num_tokens_messages, - messages=in_context_messages, - # related to functions - num_tokens_functions_definitions=num_tokens_available_functions_definitions, - functions_definitions=available_functions_definitions, + return await 
calculator.calculate_context_window( + agent_state=agent_state, + actor=actor, + token_counter=token_counter, + message_manager=self.message_manager, + passage_manager=self.passage_manager, ) diff --git a/letta/services/context_window_calculator/__init__.py b/letta/services/context_window_calculator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/services/context_window_calculator/context_window_calculator.py b/letta/services/context_window_calculator/context_window_calculator.py new file mode 100644 index 00000000..4f5c7a82 --- /dev/null +++ b/letta/services/context_window_calculator/context_window_calculator.py @@ -0,0 +1,150 @@ +import asyncio +from typing import Any, List, Optional, Tuple + +from openai.types.beta.function_tool import FunctionTool as OpenAITool + +from letta.log import get_logger +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message_content import TextContent +from letta.schemas.memory import ContextWindowOverview +from letta.schemas.user import User as PydanticUser +from letta.services.context_window_calculator.token_counter import TokenCounter + +logger = get_logger(__name__) + + +class ContextWindowCalculator: + """Handles context window calculations with different token counting strategies""" + + @staticmethod + def extract_system_components(system_message: str) -> Tuple[str, str, str]: + """Extract system prompt, core memory, and external memory summary from system message""" + base_start = system_message.find("") + memory_blocks_start = system_message.find("") + metadata_start = system_message.find("") + + system_prompt = "" + core_memory = "" + external_memory_summary = "" + + if base_start != -1 and memory_blocks_start != -1: + system_prompt = system_message[base_start:memory_blocks_start].strip() + + if memory_blocks_start != -1 and metadata_start != -1: + core_memory = system_message[memory_blocks_start:metadata_start].strip() + + if metadata_start != -1: + external_memory_summary 
= system_message[metadata_start:].strip() + + return system_prompt, core_memory, external_memory_summary + + @staticmethod + def extract_summary_memory(messages: List[Any]) -> Tuple[Optional[str], int]: + """Extract summary memory if present and return starting index for real messages""" + if ( + len(messages) > 1 + and messages[1].role == MessageRole.user + and messages[1].content + and len(messages[1].content) == 1 + and isinstance(messages[1].content[0], TextContent) + and "The following is a summary of the previous " in messages[1].content[0].text + ): + summary_memory = messages[1].content[0].text + start_index = 2 + return summary_memory, start_index + + return None, 1 + + async def calculate_context_window( + self, agent_state: Any, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any + ) -> ContextWindowOverview: + """Calculate context window information using the provided token counter""" + + # Fetch data concurrently + (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( + message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor), + passage_manager.size_async(actor=actor, agent_id=agent_state.id), + message_manager.size_async(actor=actor, agent_id=agent_state.id), + ) + + # Convert messages to appropriate format + converted_messages = token_counter.convert_messages(in_context_messages) + + # Extract system components + system_prompt = "" + core_memory = "" + external_memory_summary = "" + + if ( + in_context_messages + and in_context_messages[0].role == MessageRole.system + and in_context_messages[0].content + and len(in_context_messages[0].content) == 1 + and isinstance(in_context_messages[0].content[0], TextContent) + ): + system_message = in_context_messages[0].content[0].text + system_prompt, core_memory, external_memory_summary = self.extract_system_components(system_message) + + # System prompt + system_prompt = system_prompt or 
agent_state.system + + # Extract summary memory + summary_memory, message_start_index = self.extract_summary_memory(in_context_messages) + + # Prepare tool definitions + available_functions_definitions = [] + if agent_state.tools: + available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in agent_state.tools] + + # Count tokens concurrently + token_counts = await asyncio.gather( + token_counter.count_text_tokens(system_prompt), + token_counter.count_text_tokens(core_memory), + token_counter.count_text_tokens(external_memory_summary), + token_counter.count_text_tokens(summary_memory) if summary_memory else asyncio.sleep(0, result=0), + ( + token_counter.count_message_tokens(converted_messages[message_start_index:]) + if len(converted_messages) > message_start_index + else asyncio.sleep(0, result=0) + ), + ( + token_counter.count_tool_tokens(available_functions_definitions) + if available_functions_definitions + else asyncio.sleep(0, result=0) + ), + ) + + ( + num_tokens_system, + num_tokens_core_memory, + num_tokens_external_memory_summary, + num_tokens_summary_memory, + num_tokens_messages, + num_tokens_available_functions_definitions, + ) = token_counts + + num_tokens_used_total = sum(token_counts) + + return ContextWindowOverview( + # context window breakdown (in messages) + num_messages=len(in_context_messages), + num_archival_memory=passage_manager_size, + num_recall_memory=message_manager_size, + num_tokens_external_memory_summary=num_tokens_external_memory_summary, + external_memory_summary=external_memory_summary, + # top-level information + context_window_size_max=agent_state.llm_config.context_window, + context_window_size_current=num_tokens_used_total, + # context window breakdown (in tokens) + num_tokens_system=num_tokens_system, + system_prompt=system_prompt, + num_tokens_core_memory=num_tokens_core_memory, + core_memory=core_memory, + num_tokens_summary_memory=num_tokens_summary_memory, + 
summary_memory=summary_memory, + num_tokens_messages=num_tokens_messages, + messages=in_context_messages, + # related to functions + num_tokens_functions_definitions=num_tokens_available_functions_definitions, + functions_definitions=available_functions_definitions, + ) diff --git a/letta/services/context_window_calculator/token_counter.py b/letta/services/context_window_calculator/token_counter.py new file mode 100644 index 00000000..764b71c3 --- /dev/null +++ b/letta/services/context_window_calculator/token_counter.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from letta.llm_api.anthropic_client import AnthropicClient +from letta.utils import count_tokens + + +class TokenCounter(ABC): + """Abstract base class for token counting strategies""" + + @abstractmethod + async def count_text_tokens(self, text: str) -> int: + """Count tokens in a text string""" + + @abstractmethod + async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int: + """Count tokens in a list of messages""" + + @abstractmethod + async def count_tool_tokens(self, tools: List[Any]) -> int: + """Count tokens in tool definitions""" + + @abstractmethod + def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: + """Convert messages to the appropriate format for this counter""" + + +class AnthropicTokenCounter(TokenCounter): + """Token counter using Anthropic's API""" + + def __init__(self, anthropic_client: AnthropicClient, model: str): + self.client = anthropic_client + self.model = model + + async def count_text_tokens(self, text: str) -> int: + if not text: + return 0 + return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "content": text}]) + + async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int: + if not messages: + return 0 + return await self.client.count_tokens(model=self.model, messages=messages) + + async def count_tool_tokens(self, tools: List[Any]) -> 
int: + if not tools: + return 0 + return await self.client.count_tokens(model=self.model, tools=tools) + + def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: + return [m.to_anthropic_dict() for m in messages] + + +class TiktokenCounter(TokenCounter): + """Token counter using tiktoken""" + + def __init__(self, model: str): + self.model = model + + async def count_text_tokens(self, text: str) -> int: + if not text: + return 0 + return count_tokens(text) + + async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int: + if not messages: + return 0 + from letta.local_llm.utils import num_tokens_from_messages + + return num_tokens_from_messages(messages=messages, model=self.model) + + async def count_tool_tokens(self, tools: List[Any]) -> int: + if not tools: + return 0 + from letta.local_llm.utils import num_tokens_from_functions + + # Extract function definitions from OpenAITool objects + functions = [t.function.model_dump() for t in tools] + return num_tokens_from_functions(functions=functions, model=self.model) + + def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: + return [m.to_openai_dict() for m in messages] diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 2bb06982..7674ee1a 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -1,10 +1,13 @@ import functools import time -from typing import Union +from typing import Optional, Union from letta.functions.functions import parse_source_code from letta.functions.schema_generator import generate_schema from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent +from letta.schemas.enums import MessageRole +from letta.schemas.file import FileAgent +from letta.schemas.memory import ContextWindowOverview from letta.schemas.tool import Tool from letta.schemas.user import User from letta.schemas.user import User as PydanticUser @@ -159,3 +162,79 @@ def comprehensive_agent_checks(agent: AgentState, request: 
Union[CreateAgent, Up # Assert message_buffer_autoclear if not request.message_buffer_autoclear is None: assert agent.message_buffer_autoclear == request.message_buffer_autoclear + + +def validate_context_window_overview(overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None) -> None: + """Validate common sense assertions for ContextWindowOverview""" + + # 1. Current context size should not exceed maximum + assert ( + overview.context_window_size_current <= overview.context_window_size_max + ), f"Current context size ({overview.context_window_size_current}) exceeds maximum ({overview.context_window_size_max})" + + # 2. All token counts should be non-negative + assert overview.num_tokens_system >= 0, "System token count cannot be negative" + assert overview.num_tokens_core_memory >= 0, "Core memory token count cannot be negative" + assert overview.num_tokens_external_memory_summary >= 0, "External memory summary token count cannot be negative" + assert overview.num_tokens_summary_memory >= 0, "Summary memory token count cannot be negative" + assert overview.num_tokens_messages >= 0, "Messages token count cannot be negative" + assert overview.num_tokens_functions_definitions >= 0, "Functions definitions token count cannot be negative" + + # 3. Token components should sum to total + expected_total = ( + overview.num_tokens_system + + overview.num_tokens_core_memory + + overview.num_tokens_external_memory_summary + + overview.num_tokens_summary_memory + + overview.num_tokens_messages + + overview.num_tokens_functions_definitions + ) + assert ( + overview.context_window_size_current == expected_total + ), f"Token sum ({expected_total}) doesn't match current size ({overview.context_window_size_current})" + + # 4. Message count should match messages list length + assert ( + len(overview.messages) == overview.num_messages + ), f"Messages list length ({len(overview.messages)}) doesn't match num_messages ({overview.num_messages})" + + # 5. 
If summary_memory is None, its token count should be 0 + if overview.summary_memory is None: + assert overview.num_tokens_summary_memory == 0, "Summary memory is None but has non-zero token count" + + # 6. External memory summary consistency + assert overview.num_tokens_external_memory_summary > 0, "External memory summary exists but has zero token count" + + # 7. System prompt consistency + assert overview.num_tokens_system > 0, "System prompt exists but has zero token count" + + # 8. Core memory consistency + assert overview.num_tokens_core_memory > 0, "Core memory exists but has zero token count" + + # 9. Functions definitions consistency + assert overview.num_tokens_functions_definitions > 0, "Functions definitions exist but have zero token count" + assert len(overview.functions_definitions) > 0, "Functions definitions list should not be empty" + + # 10. Memory counts should be non-negative + assert overview.num_archival_memory >= 0, "Archival memory count cannot be negative" + assert overview.num_recall_memory >= 0, "Recall memory count cannot be negative" + + # 11. Context window max should be positive + assert overview.context_window_size_max > 0, "Maximum context window size must be positive" + + # 12. If there are messages, check basic structure + # At least one message should be system message (typical pattern) + has_system_message = any(msg.role == MessageRole.system for msg in overview.messages) + # This is a soft assertion - log warning instead of failing + if not has_system_message: + print("Warning: No system message found in messages list") + + # Average tokens per message should be reasonable (typically > 0) + avg_tokens_per_message = overview.num_tokens_messages / overview.num_messages + assert avg_tokens_per_message >= 0, "Average tokens per message should be non-negative" + + # 13. 
Check attached file is visible + if attached_file: + assert attached_file.visible_content in overview.core_memory + assert "" in overview.core_memory + assert "" in overview.core_memory diff --git a/tests/test_managers.py b/tests/test_managers.py index 010ef49f..2f2eee28 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -76,7 +76,7 @@ from letta.server.server import SyncServer from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager from letta.settings import tool_settings -from tests.helpers.utils import comprehensive_agent_checks +from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview from tests.utils import random_string DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( @@ -689,6 +689,30 @@ async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agen assert len(list_agents) == 0 +@pytest.mark.asyncio +async def test_get_context_window_basic(server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop): + # Test agent creation + created_agent, create_agent_request = comprehensive_test_agent_fixture + comprehensive_agent_checks(created_agent, create_agent_request, actor=default_user) + + # Attach a file + await server.file_agent_manager.attach_file( + agent_id=created_agent.id, + file_id=default_file.id, + actor=default_user, + visible_content="hello", + ) + + # Get context window and check for basic appearances + context_window_overview = await server.agent_manager.get_context_window(agent_id=created_agent.id, actor=default_user) + validate_context_window_overview(context_window_overview) + + # Test deleting the agent + server.agent_manager.delete_agent(created_agent.id, default_user) + list_agents = await server.agent_manager.list_agents_async(actor=default_user) + assert len(list_agents) == 0 + + @pytest.mark.asyncio async def test_create_agent_passed_in_initial_messages(server: SyncServer, 
default_user, default_block, event_loop): memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] @@ -5796,6 +5820,12 @@ async def test_attach_creates_association(server, default_user, sarah_agent, def assert assoc.is_open is True assert assoc.visible_content == "hello" + sarah_agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + file_blocks = sarah_agent.memory.file_blocks + assert len(file_blocks) == 1 + assert file_blocks[0].value == assoc.visible_content + assert file_blocks[0].label == default_file.file_name + @pytest.mark.asyncio async def test_attach_is_idempotent(server, default_user, sarah_agent, default_file): @@ -5819,6 +5849,10 @@ async def test_attach_is_idempotent(server, default_user, sarah_agent, default_f assert a2.is_open is False assert a2.visible_content == "second" + sarah_agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + file_blocks = sarah_agent.memory.file_blocks + assert len(file_blocks) == 0 # Is not open + @pytest.mark.asyncio async def test_update_file_agent(server, file_attachment, default_user): @@ -5879,6 +5913,18 @@ async def test_list_files_and_agents( agents_for_default = await server.file_agent_manager.list_agents_for_file(default_file.id, actor=default_user) assert {a.agent_id for a in agents_for_default} == {sarah_agent.id, charles_agent.id} + sarah_agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + file_blocks = sarah_agent.memory.file_blocks + assert len(file_blocks) == 1 + assert file_blocks[0].value == "" + assert file_blocks[0].label == default_file.file_name + + charles_agent = await server.agent_manager.get_agent_by_id_async(agent_id=charles_agent.id, actor=default_user) + file_blocks = charles_agent.memory.file_blocks + assert len(file_blocks) == 1 + assert file_blocks[0].value == "" + assert 
file_blocks[0].label == default_file.file_name + @pytest.mark.asyncio async def test_detach_file(server, file_attachment, default_user): diff --git a/tests/test_sources.py b/tests/test_sources.py index 9982fa37..62ea6b3e 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -59,7 +59,7 @@ def agent_state(client: LettaSDKClient): "file_path, expected_value, expected_label_regex", [ ("tests/data/test.txt", "test", r"test_[a-z0-9]+\.txt"), - ("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper_[a-z0-9]+\.pdf"), + # ("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper_[a-z0-9]+\.pdf"), ], ) def test_file_upload_creates_source_blocks_correctly( @@ -99,11 +99,8 @@ def test_file_upload_creates_source_blocks_correctly( assert len(files) == 1 assert files[0].source_id == source.id - # Check that the proper file associations were created - # files_agents = await server.file_agent_manager.list_files_for_agent(agent_id=agent_state.id, actor=actor) - # # Check that blocks were created - # blocks = client.agents.blocks.list(agent_id=agent_state.id) + # blocks = client.agents.retrieve(agent_id=agent_state.id) # assert len(blocks) == 2 # assert any(expected_value in b.value for b in blocks) # assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)