feat: support filtering out messages when converting to openai dict (#4337)
* feat: support filtering out messages when converting to openai dict * fix imports
This commit is contained in:
@@ -1106,7 +1106,7 @@ class Agent(BaseAgent):
|
||||
|
||||
def summarize_messages_inplace(self):
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
|
||||
in_context_messages_openai_no_system = in_context_messages_openai[1:]
|
||||
token_counts = get_token_counts_for_messages(in_context_messages)
|
||||
logger.info(f"System message token count={token_counts[0]}")
|
||||
@@ -1212,7 +1212,7 @@ class Agent(BaseAgent):
|
||||
# Grab the in-context messages
|
||||
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
|
||||
|
||||
# Check if there's a summary message in the message queue
|
||||
if (
|
||||
@@ -1312,7 +1312,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
|
||||
|
||||
# Extract system, memory and external summary
|
||||
if (
|
||||
|
||||
@@ -248,7 +248,7 @@ class CLIInterface(AgentInterface):
|
||||
@staticmethod
|
||||
def print_messages(message_sequence: List[Message], dump=False):
|
||||
# rewrite to dict format
|
||||
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
|
||||
message_sequence = Message.to_openai_dicts_from_list(message_sequence)
|
||||
|
||||
idx = len(message_sequence)
|
||||
for msg in message_sequence:
|
||||
@@ -291,7 +291,7 @@ class CLIInterface(AgentInterface):
|
||||
@staticmethod
|
||||
def print_messages_simple(message_sequence: List[Message]):
|
||||
# rewrite to dict format
|
||||
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
|
||||
message_sequence = Message.to_openai_dicts_from_list(message_sequence)
|
||||
|
||||
for msg in message_sequence:
|
||||
role = msg["role"]
|
||||
@@ -309,7 +309,7 @@ class CLIInterface(AgentInterface):
|
||||
@staticmethod
|
||||
def print_messages_raw(message_sequence: List[Message]):
|
||||
# rewrite to dict format
|
||||
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
|
||||
message_sequence = Message.to_openai_dicts_from_list(message_sequence)
|
||||
|
||||
for msg in message_sequence:
|
||||
print(msg)
|
||||
|
||||
@@ -113,6 +113,7 @@ class OpenAIStreamingInterface:
|
||||
if self.messages:
|
||||
# Convert messages to dict format for token counting
|
||||
message_dicts = [msg.to_openai_dict() if hasattr(msg, "to_openai_dict") else msg for msg in self.messages]
|
||||
message_dicts = [m for m in message_dicts if m is not None]
|
||||
self.fallback_input_tokens = num_tokens_from_messages(message_dicts) # fallback to gpt-4 cl100k-base
|
||||
|
||||
if self.tools:
|
||||
|
||||
@@ -11,7 +11,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage, Message as _Message
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
AssistantMessage,
|
||||
ChatCompletionRequest,
|
||||
@@ -119,7 +119,9 @@ def build_deepseek_chat_completions_request(
|
||||
# inner_thoughts_description=inner_thoughts_desc,
|
||||
# )
|
||||
|
||||
openai_message_list = [cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=False)) for m in messages]
|
||||
openai_message_list = [
|
||||
cast_message_to_subtype(m) for m in PydanticMessage.to_openai_dicts_from_list(messages, put_inner_thoughts_in_kwargs=False)
|
||||
]
|
||||
|
||||
if llm_config.model:
|
||||
model = llm_config.model
|
||||
@@ -343,7 +345,9 @@ class DeepseekClient(OpenAIClient):
|
||||
system_message.content += f"<available functions> {''.join(json.dumps(f) for f in tools)} </available functions>"
|
||||
system_message.content += 'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.'
|
||||
|
||||
openai_message_list = [cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=False)) for m in messages]
|
||||
openai_message_list = [
|
||||
cast_message_to_subtype(m) for m in PydanticMessage.to_openai_dicts_from_list(messages, put_inner_thoughts_in_kwargs=False)
|
||||
]
|
||||
|
||||
if llm_config.model == "deepseek-reasoner": # R1 currently doesn't support function calling natively
|
||||
add_functions_to_system_message(
|
||||
|
||||
@@ -310,7 +310,7 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts
|
||||
f"Given in_context_messages has different length from given token_counts: {len(in_context_messages)} != {len(token_counts)}"
|
||||
)
|
||||
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
|
||||
|
||||
if summarizer_settings.evict_all_messages:
|
||||
logger.info("Evicting all messages...")
|
||||
@@ -351,7 +351,7 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts
|
||||
|
||||
|
||||
def get_token_counts_for_messages(in_context_messages: List[Message]) -> List[int]:
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
|
||||
token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai]
|
||||
return token_counts
|
||||
|
||||
|
||||
@@ -145,7 +145,7 @@ def create(
|
||||
|
||||
# Count the tokens first, if there's an overflow exit early by throwing an error up the stack
|
||||
# NOTE: we want to include a specific substring in the error message to trigger summarization
|
||||
messages_oai_format = [m.to_openai_dict() for m in messages]
|
||||
messages_oai_format = Message.to_openai_dicts_from_list(messages)
|
||||
prompt_tokens = num_tokens_from_messages(messages=messages_oai_format, model=llm_config.model)
|
||||
function_tokens = num_tokens_from_functions(functions=functions, model=llm_config.model) if functions else 0
|
||||
if prompt_tokens + function_tokens > llm_config.context_window:
|
||||
|
||||
@@ -21,7 +21,7 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as _Message, MessageRole as _MessageRole
|
||||
from letta.schemas.message import Message as PydanticMessage, MessageRole as _MessageRole
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
ChatCompletionRequest,
|
||||
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||
@@ -177,7 +177,7 @@ async def openai_get_model_list_async(
|
||||
|
||||
def build_openai_chat_completions_request(
|
||||
llm_config: LLMConfig,
|
||||
messages: List[_Message],
|
||||
messages: List[PydanticMessage],
|
||||
user_id: Optional[str],
|
||||
functions: Optional[list],
|
||||
function_call: Optional[str],
|
||||
@@ -201,13 +201,12 @@ def build_openai_chat_completions_request(
|
||||
use_developer_message = accepts_developer_role(llm_config.model)
|
||||
|
||||
openai_message_list = [
|
||||
cast_message_to_subtype(
|
||||
m.to_openai_dict(
|
||||
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
use_developer_message=use_developer_message,
|
||||
)
|
||||
cast_message_to_subtype(m)
|
||||
for m in PydanticMessage.to_openai_dicts_from_list(
|
||||
messages,
|
||||
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
use_developer_message=use_developer_message,
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
if llm_config.model:
|
||||
@@ -326,7 +325,7 @@ def openai_chat_completions_process_stream(
|
||||
|
||||
# Create a dummy Message object to get an ID and date
|
||||
# TODO(sarah): add message ID generation function
|
||||
dummy_message = _Message(
|
||||
dummy_message = PydanticMessage(
|
||||
role=_MessageRole.assistant,
|
||||
content=[],
|
||||
agent_id="",
|
||||
|
||||
@@ -179,13 +179,12 @@ class OpenAIClient(LLMClientBase):
|
||||
use_developer_message = accepts_developer_role(llm_config.model)
|
||||
|
||||
openai_message_list = [
|
||||
cast_message_to_subtype(
|
||||
m.to_openai_dict(
|
||||
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
use_developer_message=use_developer_message,
|
||||
)
|
||||
cast_message_to_subtype(m)
|
||||
for m in PydanticMessage.to_openai_dicts_from_list(
|
||||
messages,
|
||||
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
use_developer_message=use_developer_message,
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
if llm_config.model:
|
||||
|
||||
@@ -22,6 +22,7 @@ from letta.local_llm.webui.api import get_webui_completion
|
||||
from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
|
||||
from letta.otel.tracing import log_event
|
||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, UsageStatistics
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
@@ -61,7 +62,7 @@ def get_chat_completion(
|
||||
|
||||
# TODO: eventually just process Message object
|
||||
if not isinstance(messages[0], dict):
|
||||
messages = [m.to_openai_dict() for m in messages]
|
||||
messages = PydanticMessage.to_openai_dicts_from_list(messages)
|
||||
|
||||
if function_call is not None and function_call != "auto":
|
||||
raise ValueError(f"function_call == {function_call} not supported (auto or None only)")
|
||||
|
||||
@@ -730,7 +730,7 @@ class Message(BaseMessage):
|
||||
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
use_developer_message: bool = False,
|
||||
) -> dict:
|
||||
) -> dict | None:
|
||||
"""Go from Message class to ChatCompletion message object"""
|
||||
|
||||
# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
|
||||
@@ -822,6 +822,24 @@ class Message(BaseMessage):
|
||||
|
||||
return openai_message
|
||||
|
||||
@staticmethod
|
||||
def to_openai_dicts_from_list(
|
||||
messages: List[Message],
|
||||
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
use_developer_message: bool = False,
|
||||
) -> List[dict]:
|
||||
result = [
|
||||
m.to_openai_dict(
|
||||
max_tool_id_length=max_tool_id_length,
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
use_developer_message=use_developer_message,
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
result = [m for m in result if m is not None]
|
||||
return result
|
||||
|
||||
def to_anthropic_dict(
|
||||
self,
|
||||
inner_thoughts_xml_tag="thinking",
|
||||
|
||||
@@ -435,7 +435,9 @@ def convert_in_context_letta_messages_to_openai(in_context_messages: List[Messag
|
||||
pass # It's not JSON, leave as-is
|
||||
|
||||
# Finally, convert to dict using your existing method
|
||||
openai_messages.append(msg.to_openai_dict())
|
||||
m = msg.to_openai_dict()
|
||||
assert m is not None
|
||||
openai_messages.append(m)
|
||||
|
||||
return openai_messages
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
|
||||
from letta.helpers.decorators import async_redis_cache
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.utils import count_tokens
|
||||
|
||||
@@ -124,4 +125,4 @@ class TiktokenCounter(TokenCounter):
|
||||
return num_tokens_from_functions(functions=functions, model=self.model)
|
||||
|
||||
def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
return [m.to_openai_dict() for m in messages]
|
||||
return Message.to_openai_dicts_from_list(messages)
|
||||
|
||||
@@ -295,7 +295,9 @@ class Summarizer:
|
||||
def simple_formatter(messages: List[Message], include_system: bool = False) -> str:
|
||||
"""Go from an OpenAI-style list of messages to a concatenated string"""
|
||||
|
||||
parsed_messages = [message.to_openai_dict() for message in messages if message.role != MessageRole.system or include_system]
|
||||
parsed_messages = Message.to_openai_dicts_from_list(
|
||||
[message for message in messages if message.role != MessageRole.system or include_system]
|
||||
)
|
||||
return "\n".join(json.dumps(msg) for msg in parsed_messages)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user