feat(core): add model/model_settings override fields to conversation create/update (#9607)

This commit is contained in:
Sarah Wooders
2026-02-22 23:12:05 -08:00
committed by Caren Thomas
parent a9a6a5f29d
commit afbc416972
7 changed files with 298 additions and 3 deletions

View File

@@ -1,18 +1,22 @@
import uuid
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import ForeignKey, Index, String
from pydantic import TypeAdapter
from sqlalchemy import JSON, ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.conversation import Conversation as PydanticConversation
from letta.schemas.model import ModelSettingsUnion
if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.block import Block
from letta.orm.conversation_messages import ConversationMessage
_model_settings_adapter = TypeAdapter(ModelSettingsUnion)
class Conversation(SqlalchemyBase, OrganizationMixin):
"""Conversations that can be created on an agent for concurrent messaging."""
@@ -27,6 +31,12 @@ class Conversation(SqlalchemyBase, OrganizationMixin):
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv-{uuid.uuid4()}")
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)
summary: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Summary of the conversation")
model: Mapped[Optional[str]] = mapped_column(
String, nullable=True, doc="Model handle override for this conversation (format: provider/model-name)"
)
model_settings: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, doc="Model settings override for this conversation (provider-specific settings)"
)
# Relationships
agent: Mapped["Agent"] = relationship("Agent", back_populates="conversations", lazy="raise")
@@ -55,4 +65,6 @@ class Conversation(SqlalchemyBase, OrganizationMixin):
created_by_id=self.created_by_id,
last_updated_by_id=self.last_updated_by_id,
isolated_block_ids=[b.id for b in self.isolated_blocks] if self.isolated_blocks else [],
model=self.model,
model_settings=_model_settings_adapter.validate_python(self.model_settings) if self.model_settings else None,
)

View File

@@ -1,8 +1,10 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from letta.errors import LettaInvalidArgumentError
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.model import ModelSettingsUnion
class Conversation(OrmMetadataBase):
@@ -18,6 +20,14 @@ class Conversation(OrmMetadataBase):
default_factory=list,
description="IDs of blocks that are isolated (specific to this conversation, overriding agent defaults).",
)
model: Optional[str] = Field(
None,
description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.",
)
model_settings: Optional[ModelSettingsUnion] = Field(
None,
description="The model settings for this conversation (overrides agent's model settings).",
)
class CreateConversation(BaseModel):
@@ -29,9 +39,49 @@ class CreateConversation(BaseModel):
description="List of block labels that should be isolated (conversation-specific) rather than shared across conversations. "
"New blocks will be created as copies of the agent's blocks with these labels.",
)
model: Optional[str] = Field(
None,
description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.",
)
model_settings: Optional[ModelSettingsUnion] = Field(
None,
description="The model settings for this conversation (overrides agent's model settings).",
)
@field_validator("model")
@classmethod
def validate_model(cls, model: Optional[str]) -> Optional[str]:
    """Ensure a provided model handle follows the ``provider/model-name`` format.

    ``None``/empty means "no override" and is passed through unchanged.
    Raises ``LettaInvalidArgumentError`` when the separator, provider, or
    model name is missing.
    """
    if not model:
        return model
    # partition splits on the first "/" only, so model names may themselves
    # contain slashes (same behavior as split("/", 1)).
    provider, sep, name = model.partition("/")
    if not sep or not provider or not name:
        raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model")
    return model
class UpdateConversation(BaseModel):
    """Request model for updating a conversation.

    Every field is optional; only the fields supplied by the caller are
    applied to the stored conversation.
    """

    summary: Optional[str] = Field(None, description="A summary of the conversation.")
    model: Optional[str] = Field(
        None,
        description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.",
    )
    model_settings: Optional[ModelSettingsUnion] = Field(
        None,
        description="The model settings for this conversation (overrides agent's model settings).",
    )

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: Optional[str]) -> Optional[str]:
        """Ensure a provided model handle follows ``provider/model-name``.

        ``None``/empty means "no override" and is passed through unchanged.
        """
        if not model:
            return model
        # partition splits on the first "/" only, so model names may
        # themselves contain slashes (same behavior as split("/", 1)).
        provider, sep, name = model.partition("/")
        if not sep or not provider or not name:
            raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model")
        return model

View File

@@ -250,6 +250,17 @@ async def send_conversation_message(
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
)
# Apply conversation-level model override if set (lower priority than request override)
if conversation.model and not request.override_model:
conversation_llm_config = await server.get_llm_config_from_handle_async(
actor=actor,
handle=conversation.model,
)
if conversation.model_settings is not None:
update_params = conversation.model_settings._to_legacy_config_params()
conversation_llm_config = conversation_llm_config.model_copy(update=update_params)
agent = agent.model_copy(update={"llm_config": conversation_llm_config})
if request.override_model:
override_llm_config = await server.get_llm_config_from_handle_async(
actor=actor,

View File

@@ -54,6 +54,8 @@ class ConversationManager:
agent_id=agent_id,
summary=conversation_create.summary,
organization_id=actor.organization_id,
model=conversation_create.model,
model_settings=conversation_create.model_settings.model_dump() if conversation_create.model_settings else None,
)
await conversation.create_async(session, actor=actor)
@@ -185,7 +187,11 @@ class ConversationManager:
# Set attributes on the model
update_data = conversation_update.model_dump(exclude_none=True)
for key, value in update_data.items():
setattr(conversation, key, value)
# model_settings needs to be serialized to dict for the JSON column
if key == "model_settings" and value is not None:
setattr(conversation, key, conversation_update.model_settings.model_dump() if conversation_update.model_settings else value)
else:
setattr(conversation, key, value)
# Commit the update
updated_conversation = await conversation.update_async(

View File

@@ -45,6 +45,7 @@ from letta.server.rest_api.streaming_response import (
get_cancellation_event_for_run,
)
from letta.server.rest_api.utils import capture_sentry_exception
from letta.services.conversation_manager import ConversationManager
from letta.services.run_manager import RunManager
from letta.settings import settings
from letta.utils import safe_create_task
@@ -102,6 +103,22 @@ class StreamingService:
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
)
# Apply conversation-level model override if set (lower priority than request override)
if conversation_id and not request.override_model:
conversation = await ConversationManager().get_conversation_by_id(
conversation_id=conversation_id,
actor=actor,
)
if conversation.model:
conversation_llm_config = await self.server.get_llm_config_from_handle_async(
actor=actor,
handle=conversation.model,
)
if conversation.model_settings is not None:
update_params = conversation.model_settings._to_legacy_config_params()
conversation_llm_config = conversation_llm_config.model_copy(update=update_params)
agent = agent.model_copy(update={"llm_config": conversation_llm_config})
# Handle model override if specified in the request
if request.override_model:
override_llm_config = await self.server.get_llm_config_from_handle_async(