diff --git a/alembic/versions/b2c3d4e5f6a8_add_llm_config_to_conversations.py b/alembic/versions/b2c3d4e5f6a8_add_llm_config_to_conversations.py new file mode 100644 index 00000000..b8e94dc5 --- /dev/null +++ b/alembic/versions/b2c3d4e5f6a8_add_llm_config_to_conversations.py @@ -0,0 +1,29 @@ +"""Add model and model_settings columns to conversations table for model overrides + +Revision ID: b2c3d4e5f6a8 +Revises: 3e54e2fa2f7e +Create Date: 2026-02-23 02:50:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b2c3d4e5f6a8" +down_revision: Union[str, None] = "3e54e2fa2f7e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("conversations", sa.Column("model", sa.String(), nullable=True)) + op.add_column("conversations", sa.Column("model_settings", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("conversations", "model_settings") + op.drop_column("conversations", "model") diff --git a/letta/orm/conversation.py b/letta/orm/conversation.py index a3fe7a9f..d7d9a254 100644 --- a/letta/orm/conversation.py +++ b/letta/orm/conversation.py @@ -1,18 +1,22 @@ import uuid from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import ForeignKey, Index, String +from pydantic import TypeAdapter +from sqlalchemy import JSON, ForeignKey, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.conversation import Conversation as PydanticConversation +from letta.schemas.model import ModelSettingsUnion if TYPE_CHECKING: from letta.orm.agent import Agent from letta.orm.block import Block from letta.orm.conversation_messages import ConversationMessage +_model_settings_adapter = TypeAdapter(ModelSettingsUnion) + class Conversation(SqlalchemyBase, OrganizationMixin): """Conversations that can be created on an agent for concurrent messaging.""" @@ -27,6 +31,12 @@ class Conversation(SqlalchemyBase, OrganizationMixin): id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv-{uuid.uuid4()}") agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False) summary: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Summary of the conversation") + model: Mapped[Optional[str]] = mapped_column( + String, nullable=True, doc="Model handle override for this conversation (format: provider/model-name)" + ) + model_settings: Mapped[Optional[dict]] = mapped_column( + JSON, nullable=True, doc="Model settings override for this conversation (provider-specific settings)" + ) # Relationships agent: Mapped["Agent"] = relationship("Agent", back_populates="conversations", lazy="raise") @@ -55,4 +65,6 @@ class Conversation(SqlalchemyBase, OrganizationMixin): created_by_id=self.created_by_id, last_updated_by_id=self.last_updated_by_id, isolated_block_ids=[b.id for b in self.isolated_blocks] if self.isolated_blocks else [], + model=self.model, + model_settings=_model_settings_adapter.validate_python(self.model_settings) if self.model_settings else None, ) diff --git a/letta/schemas/conversation.py b/letta/schemas/conversation.py index c2a94007..0f6bf7f8 100644 --- a/letta/schemas/conversation.py +++ b/letta/schemas/conversation.py @@ -1,8 +1,10 @@ from typing import List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator +from letta.errors import LettaInvalidArgumentError from letta.schemas.letta_base import OrmMetadataBase +from letta.schemas.model import ModelSettingsUnion class Conversation(OrmMetadataBase): @@ -18,6 +20,14 @@ class Conversation(OrmMetadataBase): default_factory=list, description="IDs of blocks that are isolated (specific to this conversation, overriding agent defaults).", ) + model: Optional[str] = Field( + None, + description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.", + ) + model_settings: Optional[ModelSettingsUnion] = Field( + None, + description="The model settings for this conversation (overrides agent's model settings).", + ) class CreateConversation(BaseModel): @@ -29,9 +39,49 @@ class CreateConversation(BaseModel): description="List of block labels that should be isolated (conversation-specific) rather than shared across conversations. " "New blocks will be created as copies of the agent's blocks with these labels.", ) + model: Optional[str] = Field( + None, + description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.", + ) + model_settings: Optional[ModelSettingsUnion] = Field( + None, + description="The model settings for this conversation (overrides agent's model settings).", + ) + + @field_validator("model") + @classmethod + def validate_model(cls, model: Optional[str]) -> Optional[str]: + if not model: + return model + if "/" not in model: + raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model") + provider_name, model_name = model.split("/", 1) + if not provider_name or not model_name: + raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model") + return model class UpdateConversation(BaseModel): """Request model for updating a conversation.""" summary: Optional[str] = Field(None, description="A summary of the conversation.") + model: Optional[str] = Field( + None, + description="The model handle for this conversation (overrides agent's model). Format: provider/model-name.", + ) + model_settings: Optional[ModelSettingsUnion] = Field( + None, + description="The model settings for this conversation (overrides agent's model settings).", + ) + + @field_validator("model") + @classmethod + def validate_model(cls, model: Optional[str]) -> Optional[str]: + if not model: + return model + if "/" not in model: + raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model") + provider_name, model_name = model.split("/", 1) + if not provider_name or not model_name: + raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model") + return model diff --git a/letta/server/rest_api/routers/v1/conversations.py b/letta/server/rest_api/routers/v1/conversations.py index d52f54bd..f52e57e5 100644 --- a/letta/server/rest_api/routers/v1/conversations.py +++ b/letta/server/rest_api/routers/v1/conversations.py @@ -250,6 +250,17 @@ async def send_conversation_message( include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"], ) + # Apply conversation-level model override if set (lower priority than request override) + if conversation.model and not request.override_model: + conversation_llm_config = await server.get_llm_config_from_handle_async( + actor=actor, + handle=conversation.model, + ) + if conversation.model_settings is not None: + update_params = conversation.model_settings._to_legacy_config_params() + conversation_llm_config = conversation_llm_config.model_copy(update=update_params) + agent = agent.model_copy(update={"llm_config": conversation_llm_config}) + if request.override_model: override_llm_config = await server.get_llm_config_from_handle_async( actor=actor, diff --git a/letta/services/conversation_manager.py b/letta/services/conversation_manager.py index 41111623..ee0cffd1 100644 --- a/letta/services/conversation_manager.py +++ b/letta/services/conversation_manager.py @@ -54,6 +54,8 @@ class ConversationManager: agent_id=agent_id, summary=conversation_create.summary, organization_id=actor.organization_id, + model=conversation_create.model, + model_settings=conversation_create.model_settings.model_dump() if conversation_create.model_settings else None, ) await conversation.create_async(session, actor=actor) @@ -185,7 +187,11 @@ class ConversationManager: # Set attributes on the model update_data = conversation_update.model_dump(exclude_none=True) for key, value in update_data.items(): - setattr(conversation, key, value) + # model_settings needs to be serialized to dict for the JSON column + if key == "model_settings" and value is not None: + setattr(conversation, key, conversation_update.model_settings.model_dump() if conversation_update.model_settings else value) + else: + setattr(conversation, key, value) # Commit the update updated_conversation = await conversation.update_async( diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py index beac6eda..496177ef 100644 --- a/letta/services/streaming_service.py +++ b/letta/services/streaming_service.py @@ -45,6 +45,7 @@ from letta.server.rest_api.streaming_response import ( get_cancellation_event_for_run, ) from letta.server.rest_api.utils import capture_sentry_exception +from letta.services.conversation_manager import ConversationManager from letta.services.run_manager import RunManager from letta.settings import settings from letta.utils import safe_create_task @@ -102,6 +103,22 @@ class StreamingService: include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"], ) + # Apply conversation-level model override if set (lower priority than request override) + if conversation_id and not request.override_model: + conversation = await ConversationManager().get_conversation_by_id( + conversation_id=conversation_id, + actor=actor, + ) + if conversation.model: + conversation_llm_config = await self.server.get_llm_config_from_handle_async( + actor=actor, + handle=conversation.model, + ) + if conversation.model_settings is not None: + update_params = conversation.model_settings._to_legacy_config_params() + conversation_llm_config = conversation_llm_config.model_copy(update=update_params) + agent = agent.model_copy(update={"llm_config": conversation_llm_config}) + # Handle model override if specified in the request if request.override_model: override_llm_config = await self.server.get_llm_config_from_handle_async( diff --git a/tests/managers/test_conversation_manager.py b/tests/managers/test_conversation_manager.py index acb807f8..8bf2d5f8 100644 --- a/tests/managers/test_conversation_manager.py +++ b/tests/managers/test_conversation_manager.py @@ -1141,3 +1141,173 @@ async def test_list_conversation_messages_order_with_pagination(conversation_man assert "Message 0" in page_asc[0].content # In descending, first should be "Message 4" assert "Message 4" in page_desc[0].content + + +# ====================================================================================================================== +# Model/Model Settings Override Tests +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_create_conversation_with_model(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test creating a conversation with a model override.""" + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test with model override", model="openai/gpt-4o"), + actor=default_user, + ) + + assert conversation.id is not None + assert conversation.model == "openai/gpt-4o" + assert conversation.model_settings is None + + +@pytest.mark.asyncio +async def test_create_conversation_with_model_and_settings(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test creating a conversation with model and model_settings.""" + from letta.schemas.model import OpenAIModelSettings + + settings = OpenAIModelSettings(temperature=0.5) + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation( + summary="Test with settings", + model="openai/gpt-4o", + model_settings=settings, + ), + actor=default_user, + ) + + assert conversation.model == "openai/gpt-4o" + assert conversation.model_settings is not None + assert conversation.model_settings.temperature == 0.5 + + +@pytest.mark.asyncio +async def test_create_conversation_without_model_override(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test creating a conversation without model override returns None for model fields.""" + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="No override"), + actor=default_user, + ) + + assert conversation.id is not None + assert conversation.model is None + assert conversation.model_settings is None + + +@pytest.mark.asyncio +async def test_update_conversation_set_model(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test updating a conversation to add a model override.""" + # Create without override + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Original"), + actor=default_user, + ) + assert conversation.model is None + + # Update to add override + updated = await conversation_manager.update_conversation( + conversation_id=conversation.id, + conversation_update=UpdateConversation(model="anthropic/claude-3-opus"), + actor=default_user, + ) + + assert updated.model == "anthropic/claude-3-opus" + + +@pytest.mark.asyncio +async def test_update_conversation_preserves_model(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test that updating summary preserves existing model override.""" + # Create with override + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="With override", model="openai/gpt-4o"), + actor=default_user, + ) + assert conversation.model == "openai/gpt-4o" + + # Update summary only + updated = await conversation_manager.update_conversation( + conversation_id=conversation.id, + conversation_update=UpdateConversation(summary="New summary"), + actor=default_user, + ) + + assert updated.summary == "New summary" + assert updated.model == "openai/gpt-4o" + + +@pytest.mark.asyncio +async def test_retrieve_conversation_includes_model(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test that retrieving a conversation includes model/model_settings.""" + from letta.schemas.model import OpenAIModelSettings + + created = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation( + summary="Retrieve test", + model="openai/gpt-4o", + model_settings=OpenAIModelSettings(temperature=0.7), + ), + actor=default_user, + ) + + retrieved = await conversation_manager.get_conversation_by_id( + conversation_id=created.id, + actor=default_user, + ) + + assert retrieved.model == "openai/gpt-4o" + assert retrieved.model_settings is not None + assert retrieved.model_settings.temperature == 0.7 + + +@pytest.mark.asyncio +async def test_list_conversations_includes_model(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test that listing conversations includes model fields.""" + await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="List test", model="openai/gpt-4o"), + actor=default_user, + ) + + conversations = await conversation_manager.list_conversations( + agent_id=sarah_agent.id, + actor=default_user, + ) + + assert len(conversations) >= 1 + conv_with_model = [c for c in conversations if c.summary == "List test"] + assert len(conv_with_model) == 1 + assert conv_with_model[0].model == "openai/gpt-4o" + + +@pytest.mark.asyncio +async def test_create_conversation_schema_model_validation(): + """Test that CreateConversation validates model handle format.""" + from letta.errors import LettaInvalidArgumentError + + # Valid format should work + create = CreateConversation(model="openai/gpt-4o") + assert create.model == "openai/gpt-4o" + + # Invalid format should raise + with pytest.raises(LettaInvalidArgumentError): + CreateConversation(model="invalid-no-slash") + + +@pytest.mark.asyncio +async def test_update_conversation_schema_model_validation(): + """Test that UpdateConversation validates model handle format.""" + from letta.errors import LettaInvalidArgumentError + + # Valid format should work + update = UpdateConversation(model="anthropic/claude-3-opus") + assert update.model == "anthropic/claude-3-opus" + + # Invalid format should raise + with pytest.raises(LettaInvalidArgumentError): + UpdateConversation(model="no-slash")