From 7c88470705e3086d1bbd8bcf14d7a6b072a75b4f Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 1 Sep 2025 12:48:45 -0700 Subject: [PATCH] feat: support filtering out messages when converting to openai dict (#4337) * feat: support filtering out messages when converting to openai dict * fix imports --- letta/agent.py | 6 +++--- letta/interface.py | 6 +++--- .../interfaces/openai_streaming_interface.py | 1 + letta/llm_api/deepseek_client.py | 10 +++++++--- letta/llm_api/helpers.py | 4 ++-- letta/llm_api/llm_api_tools.py | 2 +- letta/llm_api/openai.py | 17 ++++++++-------- letta/llm_api/openai_client.py | 11 +++++----- letta/local_llm/chat_completion_proxy.py | 3 ++- letta/schemas/message.py | 20 ++++++++++++++++++- letta/server/rest_api/utils.py | 4 +++- .../token_counter.py | 3 ++- letta/services/summarizer/summarizer.py | 4 +++- 13 files changed, 59 insertions(+), 32 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index cbfa7f6c..99b48f1f 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1106,7 +1106,7 @@ class Agent(BaseAgent): def summarize_messages_inplace(self): in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) - in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] + in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) in_context_messages_openai_no_system = in_context_messages_openai[1:] token_counts = get_token_counts_for_messages(in_context_messages) logger.info(f"System message token count={token_counts[0]}") @@ -1212,7 +1212,7 @@ class Agent(BaseAgent): # Grab the in-context messages # conversion of messages to OpenAI dict format, which is passed to the token counter in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) - in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] + in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) # Check if there's a summary message in the message queue if ( @@ -1312,7 +1312,7 @@ class Agent(BaseAgent): ) # conversion of messages to OpenAI dict format, which is passed to the token counter - in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] + in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) # Extract system, memory and external summary if ( diff --git a/letta/interface.py b/letta/interface.py index 733aaddb..7e146b07 100644 --- a/letta/interface.py +++ b/letta/interface.py @@ -248,7 +248,7 @@ class CLIInterface(AgentInterface): @staticmethod def print_messages(message_sequence: List[Message], dump=False): # rewrite to dict format - message_sequence = [msg.to_openai_dict() for msg in message_sequence] + message_sequence = Message.to_openai_dicts_from_list(message_sequence) idx = len(message_sequence) for msg in message_sequence: @@ -291,7 +291,7 @@ class CLIInterface(AgentInterface): @staticmethod def print_messages_simple(message_sequence: List[Message]): # rewrite to dict format - message_sequence = [msg.to_openai_dict() for msg in message_sequence] + message_sequence = Message.to_openai_dicts_from_list(message_sequence) for msg in message_sequence: role = msg["role"] @@ -309,7 +309,7 @@ class CLIInterface(AgentInterface): @staticmethod def print_messages_raw(message_sequence: List[Message]): # rewrite to dict format - message_sequence = [msg.to_openai_dict() for msg in message_sequence] + message_sequence = Message.to_openai_dicts_from_list(message_sequence) for msg in message_sequence: print(msg) diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 7bf3a1a3..08f23e0f 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -113,6 +113,7 @@ class OpenAIStreamingInterface: if self.messages: # Convert messages to dict format for token counting message_dicts = [msg.to_openai_dict() if hasattr(msg, "to_openai_dict") else msg for msg in self.messages] + message_dicts = [m for m in message_dicts if m is not None] self.fallback_input_tokens = num_tokens_from_messages(message_dicts) # fallback to gpt-4 cl100k-base if self.tools: diff --git a/letta/llm_api/deepseek_client.py b/letta/llm_api/deepseek_client.py index 7d02fcd9..a0037b1e 100644 --- a/letta/llm_api/deepseek_client.py +++ b/letta/llm_api/deepseek_client.py @@ -11,7 +11,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from letta.llm_api.openai_client import OpenAIClient from letta.otel.tracing import trace_method from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message as PydanticMessage, Message as _Message +from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import ( AssistantMessage, ChatCompletionRequest, @@ -119,7 +119,9 @@ def build_deepseek_chat_completions_request( # inner_thoughts_description=inner_thoughts_desc, # ) - openai_message_list = [cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=False)) for m in messages] + openai_message_list = [ + cast_message_to_subtype(m) for m in PydanticMessage.to_openai_dicts_from_list(messages, put_inner_thoughts_in_kwargs=False) + ] if llm_config.model: model = llm_config.model @@ -343,7 +345,9 @@ class DeepseekClient(OpenAIClient): system_message.content += f" {''.join(json.dumps(f) for f in tools)} " system_message.content += 'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.' - openai_message_list = [cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=False)) for m in messages] + openai_message_list = [ + cast_message_to_subtype(m) for m in PydanticMessage.to_openai_dicts_from_list(messages, put_inner_thoughts_in_kwargs=False) + ] if llm_config.model == "deepseek-reasoner": # R1 currently doesn't support function calling natively add_functions_to_system_message( diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index b6f3acb8..c87ec188 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -310,7 +310,7 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts f"Given in_context_messages has different length from given token_counts: {len(in_context_messages)} != {len(token_counts)}" ) - in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] + in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) if summarizer_settings.evict_all_messages: logger.info("Evicting all messages...") @@ -351,7 +351,7 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts def get_token_counts_for_messages(in_context_messages: List[Message]) -> List[int]: - in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] + in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai] return token_counts diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 3050bec6..8a75bc7b 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -145,7 +145,7 @@ def create( # Count the tokens first, if there's an overflow exit early by throwing an error up the stack # NOTE: we want to include a specific substring in the error message to trigger summarization - messages_oai_format = [m.to_openai_dict() for m in messages] + messages_oai_format = Message.to_openai_dicts_from_list(messages) prompt_tokens = num_tokens_from_messages(messages=messages_oai_format, model=llm_config.model) function_tokens = num_tokens_from_functions(functions=functions, model=llm_config.model) if functions else 0 if prompt_tokens + function_tokens > llm_config.context_window: diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index b113b7cd..da1e5a8d 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -21,7 +21,7 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes from letta.log import get_logger from letta.otel.tracing import log_event from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message as _Message, MessageRole as _MessageRole +from letta.schemas.message import Message as PydanticMessage, MessageRole as _MessageRole from letta.schemas.openai.chat_completion_request import ( ChatCompletionRequest, FunctionCall as ToolFunctionChoiceFunctionCall, @@ -177,7 +177,7 @@ async def openai_get_model_list_async( def build_openai_chat_completions_request( llm_config: LLMConfig, - messages: List[_Message], + messages: List[PydanticMessage], user_id: Optional[str], functions: Optional[list], function_call: Optional[str], @@ -201,13 +201,12 @@ def build_openai_chat_completions_request( use_developer_message = accepts_developer_role(llm_config.model) openai_message_list = [ - cast_message_to_subtype( - m.to_openai_dict( - put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, - use_developer_message=use_developer_message, - ) + cast_message_to_subtype(m) + for m in PydanticMessage.to_openai_dicts_from_list( + messages, + put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, + use_developer_message=use_developer_message, ) - for m in messages ] if llm_config.model: @@ -326,7 +325,7 @@ def openai_chat_completions_process_stream( # Create a dummy Message object to get an ID and date # TODO(sarah): add message ID generation function - dummy_message = _Message( + dummy_message = PydanticMessage( role=_MessageRole.assistant, content=[], agent_id="", diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index fb541210..43ea6dc4 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -179,13 +179,12 @@ class OpenAIClient(LLMClientBase): use_developer_message = accepts_developer_role(llm_config.model) openai_message_list = [ - cast_message_to_subtype( - m.to_openai_dict( - put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, - use_developer_message=use_developer_message, - ) + cast_message_to_subtype(m) + for m in PydanticMessage.to_openai_dicts_from_list( + messages, + put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, + use_developer_message=use_developer_message, ) - for m in messages ] if llm_config.model: diff --git a/letta/local_llm/chat_completion_proxy.py b/letta/local_llm/chat_completion_proxy.py index ba0ab45a..1129b125 100644 --- a/letta/local_llm/chat_completion_proxy.py +++ b/letta/local_llm/chat_completion_proxy.py @@ -22,6 +22,7 @@ from letta.local_llm.webui.api import get_webui_completion from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy from letta.otel.tracing import log_event from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE +from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, UsageStatistics from letta.utils import get_tool_call_id @@ -61,7 +62,7 @@ def get_chat_completion( # TODO: eventually just process Message object if not isinstance(messages[0], dict): - messages = [m.to_openai_dict() for m in messages] + messages = PydanticMessage.to_openai_dicts_from_list(messages) if function_call is not None and function_call != "auto": raise ValueError(f"function_call == {function_call} not supported (auto or None only)") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index ed0af739..ee5430d6 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -730,7 +730,7 @@ class Message(BaseMessage): max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN, put_inner_thoughts_in_kwargs: bool = False, use_developer_message: bool = False, - ) -> dict: + ) -> dict | None: """Go from Message class to ChatCompletion message object""" # TODO change to pydantic casting, eg `return SystemMessageModel(self)` @@ -822,6 +822,24 @@ class Message(BaseMessage): return openai_message + @staticmethod + def to_openai_dicts_from_list( + messages: List[Message], + max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN, + put_inner_thoughts_in_kwargs: bool = False, + use_developer_message: bool = False, + ) -> List[dict]: + result = [ + m.to_openai_dict( + max_tool_id_length=max_tool_id_length, + put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, + use_developer_message=use_developer_message, + ) + for m in messages + ] + result = [m for m in result if m is not None] + return result + def to_anthropic_dict( self, inner_thoughts_xml_tag="thinking", diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index a193e57b..c72b2513 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -435,7 +435,9 @@ def convert_in_context_letta_messages_to_openai(in_context_messages: List[Messag pass # It's not JSON, leave as-is # Finally, convert to dict using your existing method - openai_messages.append(msg.to_openai_dict()) + m = msg.to_openai_dict() + assert m is not None + openai_messages.append(m) return openai_messages diff --git a/letta/services/context_window_calculator/token_counter.py b/letta/services/context_window_calculator/token_counter.py index 244a8a52..c432d72e 100644 --- a/letta/services/context_window_calculator/token_counter.py +++ b/letta/services/context_window_calculator/token_counter.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List from letta.helpers.decorators import async_redis_cache from letta.llm_api.anthropic_client import AnthropicClient from letta.otel.tracing import trace_method +from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import Tool as OpenAITool from letta.utils import count_tokens @@ -124,4 +125,4 @@ class TiktokenCounter(TokenCounter): return num_tokens_from_functions(functions=functions, model=self.model) def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: - return [m.to_openai_dict() for m in messages] + return Message.to_openai_dicts_from_list(messages) diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index 575bf351..3e4d040a 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -295,7 +295,9 @@ class Summarizer: def simple_formatter(messages: List[Message], include_system: bool = False) -> str: """Go from an OpenAI-style list of messages to a concatenated string""" - parsed_messages = [message.to_openai_dict() for message in messages if message.role != MessageRole.system or include_system] + parsed_messages = Message.to_openai_dicts_from_list( + [message for message in messages if message.role != MessageRole.system or include_system] + ) return "\n".join(json.dumps(msg) for msg in parsed_messages)