diff --git a/letta/agent.py b/letta/agent.py index 5720a4e6..55173cfb 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -23,7 +23,7 @@ from letta.errors import LLMError from letta.interface import AgentInterface from letta.llm_api.helpers import is_context_overflow_error from letta.llm_api.llm_api_tools import create -from letta.local_llm.utils import num_tokens_from_messages +from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.memory import ArchivalMemory, RecallMemory, summarize_messages from letta.metadata import MetadataStore from letta.persistence_manager import LocalStateManager @@ -33,6 +33,9 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message, UpdateMessage +from letta.schemas.openai.chat_completion_request import ( + Tool as ChatCompletionRequestTool, +) from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import ( Message as ChatCompletionMessage, @@ -1458,6 +1461,24 @@ class Agent(BaseAgent): ) num_tokens_external_memory_summary = count_tokens(external_memory_summary) + # tokens taken up by function definitions + if self.functions: + available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in self.functions] + num_tokens_available_functions_definitions = num_tokens_from_functions(functions=self.functions, model=self.model) + else: + available_functions_definitions = [] + num_tokens_available_functions_definitions = 0 + + num_tokens_used_total = ( + num_tokens_system # system prompt + + num_tokens_available_functions_definitions # function definitions + + num_tokens_core_memory # core memory + + num_tokens_external_memory_summary # metadata (statistics) about recall/archival + + num_tokens_summary_memory # summary of ongoing conversation + + num_tokens_messages # tokens taken by messages + ) + assert isinstance(num_tokens_used_total, int) + return ContextWindowOverview( # context window breakdown (in messages) num_messages=len(self._messages), @@ -1466,7 +1487,7 @@ class Agent(BaseAgent): num_tokens_external_memory_summary=num_tokens_external_memory_summary, # top-level information context_window_size_max=self.agent_state.llm_config.context_window, - context_window_size_current=num_tokens_system + num_tokens_core_memory + num_tokens_summary_memory + num_tokens_messages, + context_window_size_current=num_tokens_used_total, # context window breakdown (in tokens) num_tokens_system=num_tokens_system, system_prompt=system_prompt, @@ -1476,6 +1497,9 @@ class Agent(BaseAgent): summary_memory=summary_memory, num_tokens_messages=num_tokens_messages, messages=self._messages, + # related to functions + num_tokens_functions_definitions=num_tokens_available_functions_definitions, + functions_definitions=available_functions_definitions, ) diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 55370de5..45768fbb 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -18,8 +18,13 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as _Message from letta.schemas.message import MessageRole as _MessageRole +from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import ( - ChatCompletionRequest, + FunctionCall as ToolFunctionChoiceFunctionCall, +) +from letta.schemas.openai.chat_completion_request import ( + Tool, + ToolFunctionChoice, cast_message_to_subtype, ) from letta.schemas.openai.chat_completion_response import ( @@ -100,10 +105,10 @@ def openai_get_model_list( def build_openai_chat_completions_request( llm_config: LLMConfig, - messages: List[Message], + messages: List[_Message], user_id: Optional[str], functions: Optional[list], - function_call: str, + function_call: Optional[str], use_tool_naming: bool, max_tokens: Optional[int], ) -> ChatCompletionRequest: @@ -124,11 +129,17 @@ def build_openai_chat_completions_request( model = None if use_tool_naming: + if function_call is None: + tool_choice = None + elif function_call not in ["none", "auto", "required"]: + tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=function_call)) + else: + tool_choice = function_call data = ChatCompletionRequest( model=model, messages=openai_message_list, - tools=[{"type": "function", "function": f} for f in functions] if functions else None, - tool_choice=function_call, + tools=[Tool(type="function", function=f) for f in functions] if functions else None, + tool_choice=tool_choice, user=str(user_id), max_tokens=max_tokens, ) diff --git a/letta/local_llm/utils.py b/letta/local_llm/utils.py index 6763d0c0..2b2c153b 100644 --- a/letta/local_llm/utils.py +++ b/letta/local_llm/utils.py @@ -1,6 +1,6 @@ import os import warnings -from typing import List +from typing import List, Union import requests import tiktoken @@ -11,6 +11,7 @@ import letta.local_llm.llm_chat_completion_wrappers.configurable_wrapper as conf import letta.local_llm.llm_chat_completion_wrappers.dolphin as dolphin import letta.local_llm.llm_chat_completion_wrappers.llama3 as llama3 import letta.local_llm.llm_chat_completion_wrappers.zephyr as zephyr +from letta.schemas.openai.chat_completion_request import Tool, ToolCall def post_json_auth_request(uri, json_payload, auth_type, auth_key): @@ -123,7 +124,7 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"): return num_tokens -def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"): +def num_tokens_from_tool_calls(tool_calls: Union[List[dict], List[ToolCall]], model: str = "gpt-4"): """Based on above code (num_tokens_from_functions). Example to encode: @@ -144,10 +145,25 @@ def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"): num_tokens = 0 for tool_call in tool_calls: - function_tokens = len(encoding.encode(tool_call["id"])) - function_tokens += 2 + len(encoding.encode(tool_call["type"])) - function_tokens += 2 + len(encoding.encode(tool_call["function"]["name"])) - function_tokens += 2 + len(encoding.encode(tool_call["function"]["arguments"])) + if isinstance(tool_call, dict): + tool_call_id = tool_call["id"] + tool_call_type = tool_call["type"] + tool_call_function = tool_call["function"] + tool_call_function_name = tool_call_function["name"] + tool_call_function_arguments = tool_call_function["arguments"] + elif isinstance(tool_call, Tool): + tool_call_id = tool_call.id + tool_call_type = tool_call.type + tool_call_function = tool_call.function + tool_call_function_name = tool_call_function.name + tool_call_function_arguments = tool_call_function.arguments + else: + raise ValueError(f"Unknown tool call type: {type(tool_call)}") + + function_tokens = len(encoding.encode(tool_call_id)) + function_tokens += 2 + len(encoding.encode(tool_call_type)) + function_tokens += 2 + len(encoding.encode(tool_call_function_name)) + function_tokens += 2 + len(encoding.encode(tool_call_function_arguments)) num_tokens += function_tokens diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 91d52cc3..a6c5ad02 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from letta.schemas.block import Block from letta.schemas.message import Message +from letta.schemas.openai.chat_completion_request import Tool class ContextWindowOverview(BaseModel): @@ -41,6 +42,9 @@ class ContextWindowOverview(BaseModel): num_tokens_summary_memory: int = Field(..., description="The number of tokens in the summary memory.") summary_memory: Optional[str] = Field(None, description="The content of the summary memory.") + num_tokens_functions_definitions: int = Field(..., description="The number of tokens in the functions definitions.") + functions_definitions: Optional[List[Tool]] = Field(..., description="The content of the functions definitions.") + num_tokens_messages: int = Field(..., description="The number of tokens in the messages list.") # TODO make list of messages? # messages: List[dict] = Field(..., description="The messages in the context window.") diff --git a/letta/schemas/openai/chat_completion_request.py b/letta/schemas/openai/chat_completion_request.py index fb2d5dc4..5b7b2743 100644 --- a/letta/schemas/openai/chat_completion_request.py +++ b/letta/schemas/openai/chat_completion_request.py @@ -74,7 +74,7 @@ class ToolFunctionChoice(BaseModel): function: FunctionCall -ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice] +ToolChoice = Union[Literal["none", "auto", "required"], ToolFunctionChoice] ## tools ## @@ -117,7 +117,7 @@ class ChatCompletionRequest(BaseModel): # function-calling related tools: Optional[List[Tool]] = None - tool_choice: Optional[ToolChoice] = "none" + tool_choice: Optional[ToolChoice] = None # "none" means don't call a tool # deprecated scheme functions: Optional[List[FunctionSchema]] = None function_call: Optional[FunctionCallChoice] = None diff --git a/letta/server/server.py b/letta/server/server.py index ff136348..bf1ac91e 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -73,7 +73,12 @@ from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.letta_message import LettaMessage from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary +from letta.schemas.memory import ( + ArchivalMemorySummary, + ContextWindowOverview, + Memory, + RecallMemorySummary, +) from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage from letta.schemas.organization import Organization, OrganizationCreate from letta.schemas.passage import Passage @@ -2177,3 +2182,12 @@ class SyncServer(Server): def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig: """Add a new embedding model""" + + def get_agent_context_window( + self, + user_id: str, + agent_id: str, + ) -> ContextWindowOverview: + # Get the current message + letta_agent = self._get_or_load_agent(agent_id=agent_id) + return letta_agent.get_context_window() diff --git a/tests/test_server.py b/tests/test_server.py index cf3dc3ec..9285b25e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -507,3 +507,43 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): args_json = json.loads(last_agent_message.tool_calls[0].function.arguments) print(args_json) assert "message" in args_json and args_json["message"] is not None and args_json["message"] != new_text + + +def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str): + """Test that the context window overview fetch works""" + + overview = server.get_agent_context_window(user_id=user_id, agent_id=agent_id) + assert overview is not None + + # Run some basic checks + assert overview.context_window_size_max is not None + assert overview.context_window_size_current is not None + assert overview.num_archival_memory is not None + assert overview.num_recall_memory is not None + assert overview.num_tokens_external_memory_summary is not None + assert overview.num_tokens_system is not None + assert overview.system_prompt is not None + assert overview.num_tokens_core_memory is not None + assert overview.core_memory is not None + assert overview.num_tokens_summary_memory is not None + if overview.num_tokens_summary_memory > 0: + assert overview.summary_memory is not None + else: + assert overview.summary_memory is None + assert overview.num_tokens_functions_definitions is not None + if overview.num_tokens_functions_definitions > 0: + assert overview.functions_definitions is not None + else: + assert overview.functions_definitions is None + assert overview.num_tokens_messages is not None + assert overview.messages is not None + + assert overview.context_window_size_max >= overview.context_window_size_current + assert overview.context_window_size_current == ( + overview.num_tokens_system + + overview.num_tokens_core_memory + + overview.num_tokens_summary_memory + + overview.num_tokens_messages + + overview.num_tokens_functions_definitions + + overview.num_tokens_external_memory_summary + )