feat: support filtering out messages when converting to openai dict (#4337)

* feat: support filtering out messages when converting to openai dict

* fix imports
This commit is contained in:
cthomas
2025-09-01 12:48:45 -07:00
committed by GitHub
parent f3112f75a3
commit 7c88470705
13 changed files with 59 additions and 32 deletions

View File

@@ -1106,7 +1106,7 @@ class Agent(BaseAgent):
def summarize_messages_inplace(self):
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
in_context_messages_openai_no_system = in_context_messages_openai[1:]
token_counts = get_token_counts_for_messages(in_context_messages)
logger.info(f"System message token count={token_counts[0]}")
@@ -1212,7 +1212,7 @@ class Agent(BaseAgent):
# Grab the in-context messages
# conversion of messages to OpenAI dict format, which is passed to the token counter
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
# Check if there's a summary message in the message queue
if (
@@ -1312,7 +1312,7 @@ class Agent(BaseAgent):
)
# conversion of messages to OpenAI dict format, which is passed to the token counter
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
# Extract system, memory and external summary
if (

View File

@@ -248,7 +248,7 @@ class CLIInterface(AgentInterface):
@staticmethod
def print_messages(message_sequence: List[Message], dump=False):
# rewrite to dict format
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
message_sequence = Message.to_openai_dicts_from_list(message_sequence)
idx = len(message_sequence)
for msg in message_sequence:
@@ -291,7 +291,7 @@ class CLIInterface(AgentInterface):
@staticmethod
def print_messages_simple(message_sequence: List[Message]):
# rewrite to dict format
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
message_sequence = Message.to_openai_dicts_from_list(message_sequence)
for msg in message_sequence:
role = msg["role"]
@@ -309,7 +309,7 @@ class CLIInterface(AgentInterface):
@staticmethod
def print_messages_raw(message_sequence: List[Message]):
# rewrite to dict format
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
message_sequence = Message.to_openai_dicts_from_list(message_sequence)
for msg in message_sequence:
print(msg)

View File

@@ -113,6 +113,7 @@ class OpenAIStreamingInterface:
if self.messages:
# Convert messages to dict format for token counting
message_dicts = [msg.to_openai_dict() if hasattr(msg, "to_openai_dict") else msg for msg in self.messages]
message_dicts = [m for m in message_dicts if m is not None]
self.fallback_input_tokens = num_tokens_from_messages(message_dicts) # fallback to gpt-4 cl100k-base
if self.tools:

View File

@@ -11,7 +11,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage, Message as _Message
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import (
AssistantMessage,
ChatCompletionRequest,
@@ -119,7 +119,9 @@ def build_deepseek_chat_completions_request(
# inner_thoughts_description=inner_thoughts_desc,
# )
openai_message_list = [cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=False)) for m in messages]
openai_message_list = [
cast_message_to_subtype(m) for m in PydanticMessage.to_openai_dicts_from_list(messages, put_inner_thoughts_in_kwargs=False)
]
if llm_config.model:
model = llm_config.model
@@ -343,7 +345,9 @@ class DeepseekClient(OpenAIClient):
system_message.content += f"<available functions> {''.join(json.dumps(f) for f in tools)} </available functions>"
system_message.content += 'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.'
openai_message_list = [cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=False)) for m in messages]
openai_message_list = [
cast_message_to_subtype(m) for m in PydanticMessage.to_openai_dicts_from_list(messages, put_inner_thoughts_in_kwargs=False)
]
if llm_config.model == "deepseek-reasoner": # R1 currently doesn't support function calling natively
add_functions_to_system_message(

View File

@@ -310,7 +310,7 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts
f"Given in_context_messages has different length from given token_counts: {len(in_context_messages)} != {len(token_counts)}"
)
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
if summarizer_settings.evict_all_messages:
logger.info("Evicting all messages...")
@@ -351,7 +351,7 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts
def get_token_counts_for_messages(in_context_messages: List[Message]) -> List[int]:
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai]
return token_counts

View File

@@ -145,7 +145,7 @@ def create(
# Count the tokens first, if there's an overflow exit early by throwing an error up the stack
# NOTE: we want to include a specific substring in the error message to trigger summarization
messages_oai_format = [m.to_openai_dict() for m in messages]
messages_oai_format = Message.to_openai_dicts_from_list(messages)
prompt_tokens = num_tokens_from_messages(messages=messages_oai_format, model=llm_config.model)
function_tokens = num_tokens_from_functions(functions=functions, model=llm_config.model) if functions else 0
if prompt_tokens + function_tokens > llm_config.context_window:

View File

@@ -21,7 +21,7 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes
from letta.log import get_logger
from letta.otel.tracing import log_event
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as _Message, MessageRole as _MessageRole
from letta.schemas.message import Message as PydanticMessage, MessageRole as _MessageRole
from letta.schemas.openai.chat_completion_request import (
ChatCompletionRequest,
FunctionCall as ToolFunctionChoiceFunctionCall,
@@ -177,7 +177,7 @@ async def openai_get_model_list_async(
def build_openai_chat_completions_request(
llm_config: LLMConfig,
messages: List[_Message],
messages: List[PydanticMessage],
user_id: Optional[str],
functions: Optional[list],
function_call: Optional[str],
@@ -201,13 +201,12 @@ def build_openai_chat_completions_request(
use_developer_message = accepts_developer_role(llm_config.model)
openai_message_list = [
cast_message_to_subtype(
m.to_openai_dict(
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
use_developer_message=use_developer_message,
)
cast_message_to_subtype(m)
for m in PydanticMessage.to_openai_dicts_from_list(
messages,
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
use_developer_message=use_developer_message,
)
for m in messages
]
if llm_config.model:
@@ -326,7 +325,7 @@ def openai_chat_completions_process_stream(
# Create a dummy Message object to get an ID and date
# TODO(sarah): add message ID generation function
dummy_message = _Message(
dummy_message = PydanticMessage(
role=_MessageRole.assistant,
content=[],
agent_id="",

View File

@@ -179,13 +179,12 @@ class OpenAIClient(LLMClientBase):
use_developer_message = accepts_developer_role(llm_config.model)
openai_message_list = [
cast_message_to_subtype(
m.to_openai_dict(
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
use_developer_message=use_developer_message,
)
cast_message_to_subtype(m)
for m in PydanticMessage.to_openai_dicts_from_list(
messages,
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
use_developer_message=use_developer_message,
)
for m in messages
]
if llm_config.model:

View File

@@ -22,6 +22,7 @@ from letta.local_llm.webui.api import get_webui_completion
from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
from letta.otel.tracing import log_event
from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, UsageStatistics
from letta.utils import get_tool_call_id
@@ -61,7 +62,7 @@ def get_chat_completion(
# TODO: eventually just process Message object
if not isinstance(messages[0], dict):
messages = [m.to_openai_dict() for m in messages]
messages = PydanticMessage.to_openai_dicts_from_list(messages)
if function_call is not None and function_call != "auto":
raise ValueError(f"function_call == {function_call} not supported (auto or None only)")

View File

@@ -730,7 +730,7 @@ class Message(BaseMessage):
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
put_inner_thoughts_in_kwargs: bool = False,
use_developer_message: bool = False,
) -> dict:
) -> dict | None:
"""Go from Message class to ChatCompletion message object"""
# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
@@ -822,6 +822,24 @@ class Message(BaseMessage):
return openai_message
@staticmethod
def to_openai_dicts_from_list(
messages: List[Message],
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
put_inner_thoughts_in_kwargs: bool = False,
use_developer_message: bool = False,
) -> List[dict]:
result = [
m.to_openai_dict(
max_tool_id_length=max_tool_id_length,
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
use_developer_message=use_developer_message,
)
for m in messages
]
result = [m for m in result if m is not None]
return result
def to_anthropic_dict(
self,
inner_thoughts_xml_tag="thinking",

View File

@@ -435,7 +435,9 @@ def convert_in_context_letta_messages_to_openai(in_context_messages: List[Messag
pass # It's not JSON, leave as-is
# Finally, convert to dict using your existing method
openai_messages.append(msg.to_openai_dict())
m = msg.to_openai_dict()
assert m is not None
openai_messages.append(m)
return openai_messages

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, List
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.anthropic_client import AnthropicClient
from letta.otel.tracing import trace_method
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.utils import count_tokens
@@ -124,4 +125,4 @@ class TiktokenCounter(TokenCounter):
return num_tokens_from_functions(functions=functions, model=self.model)
def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
return [m.to_openai_dict() for m in messages]
return Message.to_openai_dicts_from_list(messages)

View File

@@ -295,7 +295,9 @@ class Summarizer:
def simple_formatter(messages: List[Message], include_system: bool = False) -> str:
"""Go from an OpenAI-style list of messages to a concatenated string"""
parsed_messages = [message.to_openai_dict() for message in messages if message.role != MessageRole.system or include_system]
parsed_messages = Message.to_openai_dicts_from_list(
[message for message in messages if message.role != MessageRole.system or include_system]
)
return "\n".join(json.dumps(msg) for msg in parsed_messages)