fix(core): prevent ModelSettings default max_output_tokens from overriding agent config (#9739)
* fix(core): prevent ModelSettings default max_output_tokens from overriding agent config

  When a conversation's model_settings were saved, the Pydantic default of max_output_tokens=4096 was always persisted to the DB, even when the client never specified it. On subsequent messages, this default would overwrite the agent's max_tokens (typically None) with 4096, silently capping output.

  Two changes:
  1. Use model_dump(exclude_unset=True) when persisting model_settings to the DB so Pydantic defaults are not saved.
  2. Add model_fields_set guards at all callsites that apply _to_legacy_config_params() to skip max_tokens when it was not explicitly provided by the caller.

  Also conditionally set max_output_tokens in the OpenAI Responses API request builder so None is not sent as null (which some models treat as a hard 4096 cap).

* nit

* Fix model_settings serialization to preserve provider_type discriminator

  Replace blanket exclude_unset=True with targeted removal of only max_output_tokens when not explicitly set. The previous approach stripped the provider_type field (a Literal with a default), which broke discriminated union deserialization when reading back from DB.
This commit is contained in:
@@ -389,7 +389,6 @@ class OpenAIClient(LLMClientBase):
|
|||||||
input=openai_messages_list,
|
input=openai_messages_list,
|
||||||
tools=responses_tools,
|
tools=responses_tools,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
max_output_tokens=llm_config.max_tokens,
|
|
||||||
temperature=llm_config.temperature if supports_temperature_param(model) else None,
|
temperature=llm_config.temperature if supports_temperature_param(model) else None,
|
||||||
parallel_tool_calls=llm_config.parallel_tool_calls if tools and supports_parallel_tool_calling(model) else False,
|
parallel_tool_calls=llm_config.parallel_tool_calls if tools and supports_parallel_tool_calling(model) else False,
|
||||||
)
|
)
|
||||||
@@ -397,6 +396,10 @@ class OpenAIClient(LLMClientBase):
|
|||||||
# Handle text configuration (verbosity and response format)
|
# Handle text configuration (verbosity and response format)
|
||||||
text_config_kwargs = {}
|
text_config_kwargs = {}
|
||||||
|
|
||||||
|
# Only set max_output_tokens if explicitly configured
|
||||||
|
if llm_config.max_tokens is not None:
|
||||||
|
data.max_output_tokens = llm_config.max_tokens
|
||||||
|
|
||||||
# Add verbosity control for GPT-5 models
|
# Add verbosity control for GPT-5 models
|
||||||
if supports_verbosity_control(model) and llm_config.verbosity:
|
if supports_verbosity_control(model) and llm_config.verbosity:
|
||||||
text_config_kwargs["verbosity"] = llm_config.verbosity
|
text_config_kwargs["verbosity"] = llm_config.verbosity
|
||||||
@@ -451,7 +454,6 @@ class OpenAIClient(LLMClientBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
request_data = data.model_dump(exclude_unset=True)
|
request_data = data.model_dump(exclude_unset=True)
|
||||||
# print("responses request data", request_data)
|
|
||||||
return request_data
|
return request_data
|
||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
|
|||||||
@@ -401,6 +401,10 @@ async def send_conversation_message(
|
|||||||
)
|
)
|
||||||
if conversation.model_settings is not None:
|
if conversation.model_settings is not None:
|
||||||
update_params = conversation.model_settings._to_legacy_config_params()
|
update_params = conversation.model_settings._to_legacy_config_params()
|
||||||
|
# Don't clobber max_tokens with the Pydantic default when the caller
|
||||||
|
# didn't explicitly provide max_output_tokens.
|
||||||
|
if "max_output_tokens" not in conversation.model_settings.model_fields_set:
|
||||||
|
update_params.pop("max_tokens", None)
|
||||||
conversation_llm_config = conversation_llm_config.model_copy(update=update_params)
|
conversation_llm_config = conversation_llm_config.model_copy(update=update_params)
|
||||||
agent = agent.model_copy(update={"llm_config": conversation_llm_config})
|
agent = agent.model_copy(update={"llm_config": conversation_llm_config})
|
||||||
|
|
||||||
|
|||||||
@@ -562,6 +562,10 @@ class SyncServer(object):
|
|||||||
# update with model_settings
|
# update with model_settings
|
||||||
if request.model_settings is not None:
|
if request.model_settings is not None:
|
||||||
update_llm_config_params = request.model_settings._to_legacy_config_params()
|
update_llm_config_params = request.model_settings._to_legacy_config_params()
|
||||||
|
# Don't clobber max_tokens with the Pydantic default when the caller
|
||||||
|
# didn't explicitly provide max_output_tokens in the request.
|
||||||
|
if "max_output_tokens" not in request.model_settings.model_fields_set:
|
||||||
|
update_llm_config_params.pop("max_tokens", None)
|
||||||
request.llm_config = request.llm_config.model_copy(update=update_llm_config_params)
|
request.llm_config = request.llm_config.model_copy(update=update_llm_config_params)
|
||||||
|
|
||||||
# Copy parallel_tool_calls from request to llm_config if provided
|
# Copy parallel_tool_calls from request to llm_config if provided
|
||||||
|
|||||||
@@ -30,6 +30,21 @@ from letta.utils import enforce_types
|
|||||||
class ConversationManager:
|
class ConversationManager:
|
||||||
"""Manager class to handle business logic related to Conversations."""
|
"""Manager class to handle business logic related to Conversations."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_model_settings(model_settings) -> Optional[dict]:
|
||||||
|
"""Serialize model settings for DB storage, stripping max_output_tokens if not explicitly set.
|
||||||
|
|
||||||
|
Uses model_dump() to preserve all fields (including the provider_type discriminator),
|
||||||
|
but removes max_output_tokens when it wasn't explicitly provided by the caller so we
|
||||||
|
don't persist the Pydantic default (4096) and later overwrite the agent's own value.
|
||||||
|
"""
|
||||||
|
if model_settings is None:
|
||||||
|
return None
|
||||||
|
data = model_settings.model_dump()
|
||||||
|
if "max_output_tokens" not in model_settings.model_fields_set:
|
||||||
|
data.pop("max_output_tokens", None)
|
||||||
|
return data
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
@trace_method
|
@trace_method
|
||||||
async def create_conversation(
|
async def create_conversation(
|
||||||
@@ -57,7 +72,7 @@ class ConversationManager:
|
|||||||
summary=conversation_create.summary,
|
summary=conversation_create.summary,
|
||||||
organization_id=actor.organization_id,
|
organization_id=actor.organization_id,
|
||||||
model=conversation_create.model,
|
model=conversation_create.model,
|
||||||
model_settings=conversation_create.model_settings.model_dump() if conversation_create.model_settings else None,
|
model_settings=self._serialize_model_settings(conversation_create.model_settings),
|
||||||
)
|
)
|
||||||
await conversation.create_async(session, actor=actor)
|
await conversation.create_async(session, actor=actor)
|
||||||
|
|
||||||
@@ -228,22 +243,15 @@ class ConversationManager:
|
|||||||
if sort_by == "last_run_completion":
|
if sort_by == "last_run_completion":
|
||||||
# Subquery to get the latest completed_at for each conversation
|
# Subquery to get the latest completed_at for each conversation
|
||||||
latest_run_subquery = (
|
latest_run_subquery = (
|
||||||
select(
|
select(RunModel.conversation_id, func.max(RunModel.completed_at).label("last_run_completion"))
|
||||||
RunModel.conversation_id,
|
|
||||||
func.max(RunModel.completed_at).label("last_run_completion")
|
|
||||||
)
|
|
||||||
.where(RunModel.conversation_id.isnot(None))
|
.where(RunModel.conversation_id.isnot(None))
|
||||||
.group_by(RunModel.conversation_id)
|
.group_by(RunModel.conversation_id)
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Join conversations with the subquery
|
# Join conversations with the subquery
|
||||||
stmt = (
|
stmt = select(ConversationModel).outerjoin(
|
||||||
select(ConversationModel)
|
latest_run_subquery, ConversationModel.id == latest_run_subquery.c.conversation_id
|
||||||
.outerjoin(
|
|
||||||
latest_run_subquery,
|
|
||||||
ConversationModel.id == latest_run_subquery.c.conversation_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
sort_column = latest_run_subquery.c.last_run_completion
|
sort_column = latest_run_subquery.c.last_run_completion
|
||||||
sort_nulls_last = True
|
sort_nulls_last = True
|
||||||
@@ -265,10 +273,12 @@ class ConversationManager:
|
|||||||
|
|
||||||
# Add summary search filter if provided
|
# Add summary search filter if provided
|
||||||
if summary_search:
|
if summary_search:
|
||||||
conditions.extend([
|
conditions.extend(
|
||||||
ConversationModel.summary.isnot(None),
|
[
|
||||||
ConversationModel.summary.contains(summary_search),
|
ConversationModel.summary.isnot(None),
|
||||||
])
|
ConversationModel.summary.contains(summary_search),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
stmt = stmt.where(and_(*conditions))
|
stmt = stmt.where(and_(*conditions))
|
||||||
|
|
||||||
@@ -277,10 +287,7 @@ class ConversationManager:
|
|||||||
# Get the sort value for the cursor conversation
|
# Get the sort value for the cursor conversation
|
||||||
if sort_by == "last_run_completion":
|
if sort_by == "last_run_completion":
|
||||||
cursor_query = (
|
cursor_query = (
|
||||||
select(
|
select(ConversationModel.id, func.max(RunModel.completed_at).label("last_run_completion"))
|
||||||
ConversationModel.id,
|
|
||||||
func.max(RunModel.completed_at).label("last_run_completion")
|
|
||||||
)
|
|
||||||
.outerjoin(RunModel, ConversationModel.id == RunModel.conversation_id)
|
.outerjoin(RunModel, ConversationModel.id == RunModel.conversation_id)
|
||||||
.where(ConversationModel.id == after)
|
.where(ConversationModel.id == after)
|
||||||
.group_by(ConversationModel.id)
|
.group_by(ConversationModel.id)
|
||||||
@@ -293,16 +300,11 @@ class ConversationManager:
|
|||||||
# Cursor is at NULL - if ascending, get non-NULLs or NULLs with greater ID
|
# Cursor is at NULL - if ascending, get non-NULLs or NULLs with greater ID
|
||||||
if ascending:
|
if ascending:
|
||||||
stmt = stmt.where(
|
stmt = stmt.where(
|
||||||
or_(
|
or_(and_(sort_column.is_(None), ConversationModel.id > after_id), sort_column.isnot(None))
|
||||||
and_(sort_column.is_(None), ConversationModel.id > after_id),
|
|
||||||
sort_column.isnot(None)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# If descending, get NULLs with smaller ID
|
# If descending, get NULLs with smaller ID
|
||||||
stmt = stmt.where(
|
stmt = stmt.where(and_(sort_column.is_(None), ConversationModel.id < after_id))
|
||||||
and_(sort_column.is_(None), ConversationModel.id < after_id)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Cursor is at non-NULL
|
# Cursor is at non-NULL
|
||||||
if ascending:
|
if ascending:
|
||||||
@@ -312,8 +314,8 @@ class ConversationManager:
|
|||||||
sort_column.isnot(None),
|
sort_column.isnot(None),
|
||||||
or_(
|
or_(
|
||||||
sort_column > after_sort_value,
|
sort_column > after_sort_value,
|
||||||
and_(sort_column == after_sort_value, ConversationModel.id > after_id)
|
and_(sort_column == after_sort_value, ConversationModel.id > after_id),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -322,7 +324,7 @@ class ConversationManager:
|
|||||||
or_(
|
or_(
|
||||||
sort_column.is_(None),
|
sort_column.is_(None),
|
||||||
sort_column < after_sort_value,
|
sort_column < after_sort_value,
|
||||||
and_(sort_column == after_sort_value, ConversationModel.id < after_id)
|
and_(sort_column == after_sort_value, ConversationModel.id < after_id),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -372,7 +374,11 @@ class ConversationManager:
|
|||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
# model_settings needs to be serialized to dict for the JSON column
|
# model_settings needs to be serialized to dict for the JSON column
|
||||||
if key == "model_settings" and value is not None:
|
if key == "model_settings" and value is not None:
|
||||||
setattr(conversation, key, conversation_update.model_settings.model_dump() if conversation_update.model_settings else value)
|
setattr(
|
||||||
|
conversation,
|
||||||
|
key,
|
||||||
|
self._serialize_model_settings(conversation_update.model_settings) if conversation_update.model_settings else value,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
setattr(conversation, key, value)
|
setattr(conversation, key, value)
|
||||||
|
|
||||||
|
|||||||
@@ -119,6 +119,10 @@ class StreamingService:
|
|||||||
)
|
)
|
||||||
if conversation.model_settings is not None:
|
if conversation.model_settings is not None:
|
||||||
update_params = conversation.model_settings._to_legacy_config_params()
|
update_params = conversation.model_settings._to_legacy_config_params()
|
||||||
|
# Don't clobber max_tokens with the Pydantic default when the caller
|
||||||
|
# didn't explicitly provide max_output_tokens.
|
||||||
|
if "max_output_tokens" not in conversation.model_settings.model_fields_set:
|
||||||
|
update_params.pop("max_tokens", None)
|
||||||
conversation_llm_config = conversation_llm_config.model_copy(update=update_params)
|
conversation_llm_config = conversation_llm_config.model_copy(update=update_params)
|
||||||
agent = agent.model_copy(update={"llm_config": conversation_llm_config})
|
agent = agent.model_copy(update={"llm_config": conversation_llm_config})
|
||||||
|
|
||||||
|
|||||||
@@ -96,6 +96,10 @@ async def build_summarizer_llm_config(
|
|||||||
# them just like server.create_agent_async does for agents.
|
# them just like server.create_agent_async does for agents.
|
||||||
if summarizer_config.model_settings is not None:
|
if summarizer_config.model_settings is not None:
|
||||||
update_params = summarizer_config.model_settings._to_legacy_config_params()
|
update_params = summarizer_config.model_settings._to_legacy_config_params()
|
||||||
|
# Don't clobber max_tokens with the Pydantic default when the caller
|
||||||
|
# didn't explicitly provide max_output_tokens.
|
||||||
|
if "max_output_tokens" not in summarizer_config.model_settings.model_fields_set:
|
||||||
|
update_params.pop("max_tokens", None)
|
||||||
return base.model_copy(update=update_params)
|
return base.model_copy(update=update_params)
|
||||||
|
|
||||||
return base
|
return base
|
||||||
|
|||||||
Reference in New Issue
Block a user