feat(core): add model/model_settings override fields to conversation create/update (#9607)
This commit is contained in:
committed by
Caren Thomas
parent
a9a6a5f29d
commit
afbc416972
@@ -1141,3 +1141,173 @@ async def test_list_conversation_messages_order_with_pagination(conversation_man
|
||||
assert "Message 0" in page_asc[0].content
|
||||
# In descending, first should be "Message 4"
|
||||
assert "Message 4" in page_desc[0].content
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Model/Model Settings Override Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_with_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test creating a conversation with a model override."""
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test with model override", model="openai/gpt-4o"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert conversation.id is not None
|
||||
assert conversation.model == "openai/gpt-4o"
|
||||
assert conversation.model_settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_with_model_and_settings(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test creating a conversation with model and model_settings."""
|
||||
from letta.schemas.model import OpenAIModelSettings
|
||||
|
||||
settings = OpenAIModelSettings(temperature=0.5)
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(
|
||||
summary="Test with settings",
|
||||
model="openai/gpt-4o",
|
||||
model_settings=settings,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert conversation.model == "openai/gpt-4o"
|
||||
assert conversation.model_settings is not None
|
||||
assert conversation.model_settings.temperature == 0.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_without_model_override(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test creating a conversation without model override returns None for model fields."""
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="No override"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert conversation.id is not None
|
||||
assert conversation.model is None
|
||||
assert conversation.model_settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation_set_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test updating a conversation to add a model override."""
|
||||
# Create without override
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Original"),
|
||||
actor=default_user,
|
||||
)
|
||||
assert conversation.model is None
|
||||
|
||||
# Update to add override
|
||||
updated = await conversation_manager.update_conversation(
|
||||
conversation_id=conversation.id,
|
||||
conversation_update=UpdateConversation(model="anthropic/claude-3-opus"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert updated.model == "anthropic/claude-3-opus"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation_preserves_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that updating summary preserves existing model override."""
|
||||
# Create with override
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="With override", model="openai/gpt-4o"),
|
||||
actor=default_user,
|
||||
)
|
||||
assert conversation.model == "openai/gpt-4o"
|
||||
|
||||
# Update summary only
|
||||
updated = await conversation_manager.update_conversation(
|
||||
conversation_id=conversation.id,
|
||||
conversation_update=UpdateConversation(summary="New summary"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert updated.summary == "New summary"
|
||||
assert updated.model == "openai/gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_conversation_includes_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that retrieving a conversation includes model/model_settings."""
|
||||
from letta.schemas.model import OpenAIModelSettings
|
||||
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(
|
||||
summary="Retrieve test",
|
||||
model="openai/gpt-4o",
|
||||
model_settings=OpenAIModelSettings(temperature=0.7),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
retrieved = await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=created.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert retrieved.model == "openai/gpt-4o"
|
||||
assert retrieved.model_settings is not None
|
||||
assert retrieved.model_settings.temperature == 0.7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_includes_model(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that listing conversations includes model fields."""
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="List test", model="openai/gpt-4o"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
conversations = await conversation_manager.list_conversations(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(conversations) >= 1
|
||||
conv_with_model = [c for c in conversations if c.summary == "List test"]
|
||||
assert len(conv_with_model) == 1
|
||||
assert conv_with_model[0].model == "openai/gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_schema_model_validation():
|
||||
"""Test that CreateConversation validates model handle format."""
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
|
||||
# Valid format should work
|
||||
create = CreateConversation(model="openai/gpt-4o")
|
||||
assert create.model == "openai/gpt-4o"
|
||||
|
||||
# Invalid format should raise
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
CreateConversation(model="invalid-no-slash")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation_schema_model_validation():
|
||||
"""Test that UpdateConversation validates model handle format."""
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
|
||||
# Valid format should work
|
||||
update = UpdateConversation(model="anthropic/claude-3-opus")
|
||||
assert update.model == "anthropic/claude-3-opus"
|
||||
|
||||
# Invalid format should raise
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
UpdateConversation(model="no-slash")
|
||||
|
||||
Reference in New Issue
Block a user