feat(core): add model/model_settings override fields to conversation create/update (#9607)
This commit is contained in:
committed by
Caren Thomas
parent
a9a6a5f29d
commit
afbc416972
@@ -0,0 +1,29 @@
|
|||||||
|
"""Add model and model_settings columns to conversations table for model overrides
|
||||||
|
|
||||||
|
Revision ID: b2c3d4e5f6a8
|
||||||
|
Revises: 3e54e2fa2f7e
|
||||||
|
Create Date: 2026-02-23 02:50:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b2c3d4e5f6a8"
|
||||||
|
down_revision: Union[str, None] = "3e54e2fa2f7e"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Apply the migration: add nullable model-override columns to conversations."""
    # Both columns are nullable, so existing rows need no backfill.
    new_columns = (
        sa.Column("model", sa.String(), nullable=True),
        sa.Column("model_settings", sa.JSON(), nullable=True),
    )
    for column in new_columns:
        op.add_column("conversations", column)
||||||
|
def downgrade() -> None:
    """Revert the migration: drop the conversation model-override columns."""
    # Drop in reverse order of creation.
    for column_name in ("model_settings", "model"):
        op.drop_column("conversations", column_name)
|
||||||
@@ -1,18 +1,22 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, Index, String
|
from pydantic import TypeAdapter
|
||||||
|
from sqlalchemy import JSON, ForeignKey, Index, String
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from letta.orm.mixins import OrganizationMixin
|
from letta.orm.mixins import OrganizationMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.conversation import Conversation as PydanticConversation
|
from letta.schemas.conversation import Conversation as PydanticConversation
|
||||||
|
from letta.schemas.model import ModelSettingsUnion
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.agent import Agent
|
from letta.orm.agent import Agent
|
||||||
from letta.orm.block import Block
|
from letta.orm.block import Block
|
||||||
from letta.orm.conversation_messages import ConversationMessage
|
from letta.orm.conversation_messages import ConversationMessage
|
||||||
|
|
||||||
|
_model_settings_adapter = TypeAdapter(ModelSettingsUnion)
|
||||||
|
|
||||||
|
|
||||||
class Conversation(SqlalchemyBase, OrganizationMixin):
|
class Conversation(SqlalchemyBase, OrganizationMixin):
|
||||||
"""Conversations that can be created on an agent for concurrent messaging."""
|
"""Conversations that can be created on an agent for concurrent messaging."""
|
||||||
@@ -27,6 +31,12 @@ class Conversation(SqlalchemyBase, OrganizationMixin):
|
|||||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv-{uuid.uuid4()}")
|
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv-{uuid.uuid4()}")
|
||||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)
|
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)
|
||||||
summary: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Summary of the conversation")
|
summary: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Summary of the conversation")
|
||||||
|
model: Mapped[Optional[str]] = mapped_column(
|
||||||
|
String, nullable=True, doc="Model handle override for this conversation (format: provider/model-name)"
|
||||||
|
)
|
||||||
|
model_settings: Mapped[Optional[dict]] = mapped_column(
|
||||||
|
JSON, nullable=True, doc="Model settings override for this conversation (provider-specific settings)"
|
||||||
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="conversations", lazy="raise")
|
agent: Mapped["Agent"] = relationship("Agent", back_populates="conversations", lazy="raise")
|
||||||
@@ -55,4 +65,6 @@ class Conversation(SqlalchemyBase, OrganizationMixin):
|
|||||||
created_by_id=self.created_by_id,
|
created_by_id=self.created_by_id,
|
||||||
last_updated_by_id=self.last_updated_by_id,
|
last_updated_by_id=self.last_updated_by_id,
|
||||||
isolated_block_ids=[b.id for b in self.isolated_blocks] if self.isolated_blocks else [],
|
isolated_block_ids=[b.id for b in self.isolated_blocks] if self.isolated_blocks else [],
|
||||||
|
model=self.model,
|
||||||
|
model_settings=_model_settings_adapter.validate_python(self.model_settings) if self.model_settings else None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from letta.errors import LettaInvalidArgumentError
|
||||||
from letta.schemas.letta_base import OrmMetadataBase
|
from letta.schemas.letta_base import OrmMetadataBase
|
||||||
|
from letta.schemas.model import ModelSettingsUnion
|
||||||
|
|
||||||
|
|
||||||
class Conversation(OrmMetadataBase):
|
class Conversation(OrmMetadataBase):
|
||||||
@@ -18,6 +20,14 @@ class Conversation(OrmMetadataBase):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="IDs of blocks that are isolated (specific to this conversation, overriding agent defaults).",
|
description="IDs of blocks that are isolated (specific to this conversation, overriding agent defaults).",
|
||||||
)
|
)
|
||||||
|
model: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.",
|
||||||
|
)
|
||||||
|
model_settings: Optional[ModelSettingsUnion] = Field(
|
||||||
|
None,
|
||||||
|
description="The model settings for this conversation (overrides agent's model settings).",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CreateConversation(BaseModel):
|
class CreateConversation(BaseModel):
|
||||||
@@ -29,9 +39,49 @@ class CreateConversation(BaseModel):
|
|||||||
description="List of block labels that should be isolated (conversation-specific) rather than shared across conversations. "
|
description="List of block labels that should be isolated (conversation-specific) rather than shared across conversations. "
|
||||||
"New blocks will be created as copies of the agent's blocks with these labels.",
|
"New blocks will be created as copies of the agent's blocks with these labels.",
|
||||||
)
|
)
|
||||||
|
model: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.",
|
||||||
|
)
|
||||||
|
model_settings: Optional[ModelSettingsUnion] = Field(
|
||||||
|
None,
|
||||||
|
description="The model settings for this conversation (overrides agent's model settings).",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("model")
|
||||||
|
@classmethod
|
||||||
|
def validate_model(cls, model: Optional[str]) -> Optional[str]:
|
||||||
|
if not model:
|
||||||
|
return model
|
||||||
|
if "/" not in model:
|
||||||
|
raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model")
|
||||||
|
provider_name, model_name = model.split("/", 1)
|
||||||
|
if not provider_name or not model_name:
|
||||||
|
raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
class UpdateConversation(BaseModel):
|
class UpdateConversation(BaseModel):
|
||||||
"""Request model for updating a conversation."""
|
"""Request model for updating a conversation."""
|
||||||
|
|
||||||
summary: Optional[str] = Field(None, description="A summary of the conversation.")
|
summary: Optional[str] = Field(None, description="A summary of the conversation.")
|
||||||
|
model: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.",
|
||||||
|
)
|
||||||
|
model_settings: Optional[ModelSettingsUnion] = Field(
|
||||||
|
None,
|
||||||
|
description="The model settings for this conversation (overrides agent's model settings).",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("model")
|
||||||
|
@classmethod
|
||||||
|
def validate_model(cls, model: Optional[str]) -> Optional[str]:
|
||||||
|
if not model:
|
||||||
|
return model
|
||||||
|
if "/" not in model:
|
||||||
|
raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model")
|
||||||
|
provider_name, model_name = model.split("/", 1)
|
||||||
|
if not provider_name or not model_name:
|
||||||
|
raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model")
|
||||||
|
return model
|
||||||
|
|||||||
@@ -250,6 +250,17 @@ async def send_conversation_message(
|
|||||||
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
|
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply conversation-level model override if set (lower priority than request override)
|
||||||
|
if conversation.model and not request.override_model:
|
||||||
|
conversation_llm_config = await server.get_llm_config_from_handle_async(
|
||||||
|
actor=actor,
|
||||||
|
handle=conversation.model,
|
||||||
|
)
|
||||||
|
if conversation.model_settings is not None:
|
||||||
|
update_params = conversation.model_settings._to_legacy_config_params()
|
||||||
|
conversation_llm_config = conversation_llm_config.model_copy(update=update_params)
|
||||||
|
agent = agent.model_copy(update={"llm_config": conversation_llm_config})
|
||||||
|
|
||||||
if request.override_model:
|
if request.override_model:
|
||||||
override_llm_config = await server.get_llm_config_from_handle_async(
|
override_llm_config = await server.get_llm_config_from_handle_async(
|
||||||
actor=actor,
|
actor=actor,
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ class ConversationManager:
|
|||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
summary=conversation_create.summary,
|
summary=conversation_create.summary,
|
||||||
organization_id=actor.organization_id,
|
organization_id=actor.organization_id,
|
||||||
|
model=conversation_create.model,
|
||||||
|
model_settings=conversation_create.model_settings.model_dump() if conversation_create.model_settings else None,
|
||||||
)
|
)
|
||||||
await conversation.create_async(session, actor=actor)
|
await conversation.create_async(session, actor=actor)
|
||||||
|
|
||||||
@@ -185,7 +187,11 @@ class ConversationManager:
|
|||||||
# Set attributes on the model
|
# Set attributes on the model
|
||||||
update_data = conversation_update.model_dump(exclude_none=True)
|
update_data = conversation_update.model_dump(exclude_none=True)
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(conversation, key, value)
|
# model_settings needs to be serialized to dict for the JSON column
|
||||||
|
if key == "model_settings" and value is not None:
|
||||||
|
setattr(conversation, key, conversation_update.model_settings.model_dump() if conversation_update.model_settings else value)
|
||||||
|
else:
|
||||||
|
setattr(conversation, key, value)
|
||||||
|
|
||||||
# Commit the update
|
# Commit the update
|
||||||
updated_conversation = await conversation.update_async(
|
updated_conversation = await conversation.update_async(
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from letta.server.rest_api.streaming_response import (
|
|||||||
get_cancellation_event_for_run,
|
get_cancellation_event_for_run,
|
||||||
)
|
)
|
||||||
from letta.server.rest_api.utils import capture_sentry_exception
|
from letta.server.rest_api.utils import capture_sentry_exception
|
||||||
|
from letta.services.conversation_manager import ConversationManager
|
||||||
from letta.services.run_manager import RunManager
|
from letta.services.run_manager import RunManager
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
from letta.utils import safe_create_task
|
from letta.utils import safe_create_task
|
||||||
@@ -102,6 +103,22 @@ class StreamingService:
|
|||||||
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
|
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply conversation-level model override if set (lower priority than request override)
|
||||||
|
if conversation_id and not request.override_model:
|
||||||
|
conversation = await ConversationManager().get_conversation_by_id(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
actor=actor,
|
||||||
|
)
|
||||||
|
if conversation.model:
|
||||||
|
conversation_llm_config = await self.server.get_llm_config_from_handle_async(
|
||||||
|
actor=actor,
|
||||||
|
handle=conversation.model,
|
||||||
|
)
|
||||||
|
if conversation.model_settings is not None:
|
||||||
|
update_params = conversation.model_settings._to_legacy_config_params()
|
||||||
|
conversation_llm_config = conversation_llm_config.model_copy(update=update_params)
|
||||||
|
agent = agent.model_copy(update={"llm_config": conversation_llm_config})
|
||||||
|
|
||||||
# Handle model override if specified in the request
|
# Handle model override if specified in the request
|
||||||
if request.override_model:
|
if request.override_model:
|
||||||
override_llm_config = await self.server.get_llm_config_from_handle_async(
|
override_llm_config = await self.server.get_llm_config_from_handle_async(
|
||||||
|
|||||||
@@ -1141,3 +1141,173 @@ async def test_list_conversation_messages_order_with_pagination(conversation_man
|
|||||||
assert "Message 0" in page_asc[0].content
|
assert "Message 0" in page_asc[0].content
|
||||||
# In descending, first should be "Message 4"
|
# In descending, first should be "Message 4"
|
||||||
assert "Message 4" in page_desc[0].content
|
assert "Message 4" in page_desc[0].content
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================================================================
|
||||||
|
# Model/Model Settings Override Tests
|
||||||
|
# ======================================================================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_conversation_with_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """A model handle supplied at creation time is persisted on the conversation."""
    conv = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test with model override", model="openai/gpt-4o"),
        actor=default_user,
    )

    assert conv.id is not None
    assert conv.model == "openai/gpt-4o"
    # No settings were supplied, so the settings override stays unset.
    assert conv.model_settings is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_conversation_with_model_and_settings(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Both a model handle and model_settings can be set when creating a conversation."""
    from letta.schemas.model import OpenAIModelSettings

    override_settings = OpenAIModelSettings(temperature=0.5)
    conv = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(
            summary="Test with settings",
            model="openai/gpt-4o",
            model_settings=override_settings,
        ),
        actor=default_user,
    )

    assert conv.model == "openai/gpt-4o"
    # Settings round-trip through the JSON column back into a typed object.
    assert conv.model_settings is not None
    assert conv.model_settings.temperature == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_conversation_without_model_override(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Without an override, both model fields default to None."""
    conv = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="No override"),
        actor=default_user,
    )

    assert conv.id is not None
    assert conv.model is None
    assert conv.model_settings is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_update_conversation_set_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """An update can add a model override to a conversation created without one."""
    # Start from a conversation with no override.
    conv = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Original"),
        actor=default_user,
    )
    assert conv.model is None

    # Set the override via update.
    updated = await conversation_manager.update_conversation(
        conversation_id=conv.id,
        conversation_update=UpdateConversation(model="anthropic/claude-3-opus"),
        actor=default_user,
    )

    assert updated.model == "anthropic/claude-3-opus"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_update_conversation_preserves_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Updating only the summary leaves an existing model override in place."""
    conv = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="With override", model="openai/gpt-4o"),
        actor=default_user,
    )
    assert conv.model == "openai/gpt-4o"

    # Touch only the summary.
    updated = await conversation_manager.update_conversation(
        conversation_id=conv.id,
        conversation_update=UpdateConversation(summary="New summary"),
        actor=default_user,
    )

    assert updated.summary == "New summary"
    # The update (exclude_none payload) must not wipe the stored model.
    assert updated.model == "openai/gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retrieve_conversation_includes_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """get_conversation_by_id round-trips model and model_settings."""
    from letta.schemas.model import OpenAIModelSettings

    created = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(
            summary="Retrieve test",
            model="openai/gpt-4o",
            model_settings=OpenAIModelSettings(temperature=0.7),
        ),
        actor=default_user,
    )

    fetched = await conversation_manager.get_conversation_by_id(
        conversation_id=created.id,
        actor=default_user,
    )

    assert fetched.model == "openai/gpt-4o"
    assert fetched.model_settings is not None
    assert fetched.model_settings.temperature == 0.7
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_list_conversations_includes_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """list_conversations surfaces the model override on each returned item."""
    await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="List test", model="openai/gpt-4o"),
        actor=default_user,
    )

    listed = await conversation_manager.list_conversations(
        agent_id=sarah_agent.id,
        actor=default_user,
    )

    assert len(listed) >= 1
    matches = [c for c in listed if c.summary == "List test"]
    assert len(matches) == 1
    assert matches[0].model == "openai/gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_conversation_schema_model_validation():
    """CreateConversation rejects model handles without a provider/model-name shape."""
    from letta.errors import LettaInvalidArgumentError

    # A well-formed handle passes validation unchanged.
    created = CreateConversation(model="openai/gpt-4o")
    assert created.model == "openai/gpt-4o"

    # A handle with no "/" separator is rejected.
    with pytest.raises(LettaInvalidArgumentError):
        CreateConversation(model="invalid-no-slash")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_update_conversation_schema_model_validation():
    """UpdateConversation rejects model handles without a provider/model-name shape."""
    from letta.errors import LettaInvalidArgumentError

    # A well-formed handle passes validation unchanged.
    updated = UpdateConversation(model="anthropic/claude-3-opus")
    assert updated.model == "anthropic/claude-3-opus"

    # A handle with no "/" separator is rejected.
    with pytest.raises(LettaInvalidArgumentError):
        UpdateConversation(model="no-slash")
|
||||||
|
|||||||
Reference in New Issue
Block a user