From 2823e4447a4a3f0a1d88d5143a07e5e28b6200ba Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 8 Jun 2025 18:28:01 -0700 Subject: [PATCH] feat: add multi-modal input support (#2590) --- letta/agent.py | 12 +- letta/agents/letta_agent.py | 7 +- letta/helpers/converters.py | 8 +- letta/helpers/message_helper.py | 42 ++++-- letta/llm_api/anthropic_client.py | 52 ++++++++ letta/llm_api/google_vertex_client.py | 45 ++++++- letta/llm_api/openai_client.py | 51 ++++++- letta/schemas/letta_message_content.py | 67 +++++++++- letta/schemas/message.py | 9 ++ .../schemas/openai/chat_completion_request.py | 2 +- letta/services/message_manager.py | 32 +++++ letta/services/passage_manager.py | 1 + tests/integration_test_send_message.py | 126 +++++++++++++++++- 13 files changed, 432 insertions(+), 22 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 153f6f16..ef8f69a9 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -46,7 +46,7 @@ from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_ from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, ProviderType -from letta.schemas.letta_message_content import TextContent +from letta.schemas.letta_message_content import ImageContent, TextContent from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message, MessageCreate, ToolReturn from letta.schemas.openai.chat_completion_response import ChatCompletionResponse @@ -369,6 +369,16 @@ class Agent(BaseAgent): ) else: # Fallback to existing flow + for message in message_sequence: + if isinstance(message.content, list): + + def get_fallback_text_content(content): + if isinstance(content, ImageContent): + return TextContent(text="[Image Here]") + return content + + message.content = [get_fallback_text_content(content) for content in message.content] + response = create( llm_config=self.agent_state.llm_config, messages=message_sequence, diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 12f1d369..2f7cbdf8 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -820,7 +820,12 @@ class LettaAgent(BaseAgent): tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True ) - return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call) + return llm_client.build_request_data( + in_context_messages, + agent_state.llm_config, + allowed_tools, + force_tool_call, + ) @trace_method async def _handle_ai_response( diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 36d47fda..78c8b45c 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -12,6 +12,8 @@ from letta.schemas.agent import AgentStepState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ProviderType, ToolRuleType from letta.schemas.letta_message_content import ( + ImageContent, + ImageSourceType, MessageContent, MessageContentType, OmittedReasoningContent, @@ -216,12 +218,13 @@ def serialize_message_content(message_content: Optional[List[Union[MessageConten serialized_message_content = [] for content in message_content: if isinstance(content, MessageContent): + if content.type == MessageContentType.image: + assert content.source.type == ImageSourceType.letta, f"Invalid image source type: {content.source.type}" serialized_message_content.append(content.model_dump(mode="json")) elif isinstance(content, dict): serialized_message_content.append(content) # Already a dictionary, leave it as-is else: raise TypeError(f"Unexpected message content type: {type(content)}") - return serialized_message_content @@ -238,6 +241,9 @@ def deserialize_message_content(data: Optional[List[Dict]]) -> List[MessageConte content_type = item.get("type") if content_type == MessageContentType.text: content = TextContent(**item) + elif content_type == MessageContentType.image: + assert item["source"]["type"] == ImageSourceType.letta, f'Invalid image source type: {item["source"]["type"]}' + content = ImageContent(**item) elif content_type == MessageContentType.tool_call: content = ToolCallContent(**item) elif content_type == MessageContentType.tool_return: diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py index 90d7b680..0c6a0584 100644 --- a/letta/helpers/message_helper.py +++ b/letta/helpers/message_helper.py @@ -1,6 +1,11 @@ +import base64 +import mimetypes + +import httpx + from letta import system from letta.schemas.enums import MessageRole -from letta.schemas.letta_message_content import TextContent +from letta.schemas.letta_message_content import Base64Image, ImageContent, ImageSourceType, TextContent from letta.schemas.message import Message, MessageCreate @@ -33,24 +38,39 @@ def _convert_message_create_to_message( # Extract message content if isinstance(message_create.content, str): + assert message_create.content != "", "Message content must not be empty" + message_content = [TextContent(text=message_create.content)] + elif isinstance(message_create.content, list) and len(message_create.content) > 0: message_content = message_create.content - elif message_create.content and len(message_create.content) > 0 and isinstance(message_create.content[0], TextContent): - message_content = message_create.content[0].text else: raise ValueError("Message content is empty or invalid") - # Apply wrapping if needed - if message_create.role not in {MessageRole.user, MessageRole.system}: - raise ValueError(f"Invalid message role: {message_create.role}") - elif message_create.role == MessageRole.user and wrap_user_message: - message_content = system.package_user_message(user_message=message_content) - elif message_create.role == MessageRole.system and wrap_system_message: - message_content = system.package_system_message(system_message=message_content) + assert message_create.role in {MessageRole.user, MessageRole.system}, f"Invalid message role: {message_create.role}" + for content in message_content: + if isinstance(content, TextContent): + # Apply wrapping if needed + if message_create.role == MessageRole.user and wrap_user_message: + content.text = system.package_user_message(user_message=content.text) + elif message_create.role == MessageRole.system and wrap_system_message: + content.text = system.package_system_message(system_message=content.text) + elif isinstance(content, ImageContent): + if content.source.type == ImageSourceType.url: + # Convert URL image to Base64Image if needed + image_response = httpx.get(content.source.url) + image_response.raise_for_status() + image_media_type = image_response.headers.get("content-type") + if not image_media_type: + image_media_type, _ = mimetypes.guess_type(content.source.url) + image_data = base64.standard_b64encode(image_response.content).decode("utf-8") + content.source = Base64Image(media_type=image_media_type, data=image_data) + if content.source.type == ImageSourceType.letta and not content.source.data: + # TODO: hydrate letta image with data from db + pass return Message( agent_id=agent_id, role=message_create.role, - content=[TextContent(text=message_content)] if message_content else [], + content=message_content, name=message_create.name, model=None, # assigned later? tool_calls=None, # irrelevant diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 41580120..3429502e 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -29,6 +29,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.enums import ProviderCategory +from letta.schemas.letta_message_content import MessageContentType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool as OpenAITool @@ -251,6 +252,8 @@ class AnthropicClient(LLMClientBase): for m in messages[1:] ] + data["messages"] = fill_image_content_in_messages(data["messages"], messages) + # Ensure first message is user if data["messages"][0]["role"] != "user": data["messages"] = [{"role": "user", "content": DUMMY_FIRST_USER_MESSAGE}] + data["messages"] @@ -656,3 +659,52 @@ def strip_xml_tags_streaming(string: str, tag: Optional[str]) -> str: result = result.replace(part, "") return result + + +def fill_image_content_in_messages(anthropic_message_list: List[dict], pydantic_message_list: List[PydanticMessage]) -> List[dict]: + """ + Converts image content to anthropic format. + """ + + if len(anthropic_message_list) != len(pydantic_message_list): + return anthropic_message_list + + new_message_list = [] + for idx in range(len(anthropic_message_list)): + anthropic_message, pydantic_message = anthropic_message_list[idx], pydantic_message_list[idx] + if pydantic_message.role != "user": + new_message_list.append(anthropic_message) + continue + + if not isinstance(pydantic_message.content, list) or ( + len(pydantic_message.content) == 1 and pydantic_message.content.type == MessageContentType.text + ): + new_message_list.append(anthropic_message) + continue + + message_content = [] + for content in pydantic_message.content: + if content.type == MessageContentType.text: + message_content.append( + { + "type": "input_text", + "text": content.text, + } + ) + elif content.type == MessageContentType.image: + message_content.append( + { + "type": "image", + "source": { + "type": content.source.type, + "media_type": content.source.media_type, + "data": content.source.data, + }, + } + ) + else: + raise ValueError(f"Unsupported content type {content.type}") + + new_message_list.append({"role": "user", "input": message_content}) + + return new_message_list diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 25b1e00f..5738a6b3 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -3,7 +3,7 @@ import uuid from typing import List, Optional from google import genai -from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig +from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, Part, ThinkingConfig, ToolConfig from letta.constants import NON_USER_MSG_PREFIX from letta.helpers.datetime_helpers import get_utc_time_int @@ -13,6 +13,7 @@ from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.local_llm.utils import count_tokens from letta.log import get_logger from letta.otel.tracing import trace_method +from letta.schemas.letta_message_content import MessageContentType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool @@ -218,7 +219,7 @@ class GoogleVertexClient(LLMClientBase): ) request_data = { - "contents": contents, + "contents": fill_image_content_in_messages(contents, messages), "config": { "temperature": llm_config.temperature, "max_output_tokens": llm_config.max_tokens, @@ -484,3 +485,43 @@ class GoogleVertexClient(LLMClientBase): "propertyOrdering": ["name", "args"], "required": ["name", "args"], } + + +def fill_image_content_in_messages(google_ai_message_list: List[dict], pydantic_message_list: List[PydanticMessage]) -> List[dict]: + """ + Converts image content to openai format. + """ + + if len(google_ai_message_list) != len(pydantic_message_list): + return google_ai_message_list + + new_message_list = [] + for idx in range(len(google_ai_message_list)): + google_ai_message, pydantic_message = google_ai_message_list[idx], pydantic_message_list[idx] + if pydantic_message.role != "user": + new_message_list.append(google_ai_message) + continue + + if not isinstance(pydantic_message.content, list) or ( + len(pydantic_message.content) == 1 and pydantic_message.content[0].type == MessageContentType.text + ): + new_message_list.append(google_ai_message) + continue + + message_content = [] + for content in pydantic_message.content: + if content.type == MessageContentType.text: + message_content.append({"text": content.text}) + elif content.type == MessageContentType.image: + message_content.append( + Part.from_bytes( + mime_type=content.source.media_type, + data=content.source.data, + ) + ) + else: + raise ValueError(f"Unsupported content type {content.type}") + + new_message_list.append({"role": "user", "input": message_content}) + + return new_message_list diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index ec289803..e0ea6187 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -26,6 +26,7 @@ from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.letta_message_content import MessageContentType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import ChatCompletionRequest @@ -213,7 +214,7 @@ class OpenAIClient(LLMClientBase): data = ChatCompletionRequest( model=model, - messages=openai_message_list, + messages=fill_image_content_in_messages(openai_message_list, messages), tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None, tool_choice=tool_choice, user=str(), @@ -402,3 +403,51 @@ class OpenAIClient(LLMClientBase): # Fallback for unexpected errors return super().handle_llm_error(e) + + +def fill_image_content_in_messages(openai_message_list: List[dict], pydantic_message_list: List[PydanticMessage]) -> List[dict]: + """ + Converts image content to openai format. + """ + + if len(openai_message_list) != len(pydantic_message_list): + return openai_message_list + + new_message_list = [] + for idx in range(len(openai_message_list)): + openai_message, pydantic_message = openai_message_list[idx], pydantic_message_list[idx] + if pydantic_message.role != "user": + new_message_list.append(openai_message) + continue + + if not isinstance(pydantic_message.content, list) or ( + len(pydantic_message.content) == 1 and pydantic_message.content[0].type == MessageContentType.text + ): + new_message_list.append(openai_message) + continue + + message_content = [] + for content in pydantic_message.content: + if content.type == MessageContentType.text: + message_content.append( + { + "type": "input_text", + "text": content.text, + } + ) + elif content.type == MessageContentType.image: + message_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:{content.source.media_type};base64,{content.source.data}", + "detail": content.source.detail or "auto", + }, + } + ) + else: + raise ValueError(f"Unsupported content type {content.type}") + + new_message_list.append({"role": "user", "content": message_content}) + + return new_message_list diff --git a/letta/schemas/letta_message_content.py b/letta/schemas/letta_message_content.py index 40092698..a9ca2144 100644 --- a/letta/schemas/letta_message_content.py +++ b/letta/schemas/letta_message_content.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field class MessageContentType(str, Enum): text = "text" + image = "image" tool_call = "tool_call" tool_return = "tool_return" reasoning = "reasoning" @@ -18,7 +19,7 @@ class MessageContent(BaseModel): # ------------------------------- -# User Content Types +# Text Content # ------------------------------- @@ -27,8 +28,62 @@ class TextContent(MessageContent): text: str = Field(..., description="The text content of the message.") +# ------------------------------- +# Image Content +# ------------------------------- + + +class ImageSourceType(str, Enum): + url = "url" + base64 = "base64" + letta = "letta" + + +class ImageSource(BaseModel): + type: ImageSourceType = Field(..., description="The source type for the image.") + + +class UrlImage(ImageSource): + type: Literal[ImageSourceType.url] = Field(ImageSourceType.url, description="The source type for the image.") + url: str = Field(..., description="The URL of the image.") + + +class Base64Image(ImageSource): + type: Literal[ImageSourceType.base64] = Field(ImageSourceType.base64, description="The source type for the image.") + media_type: str = Field(..., description="The media type for the image.") + data: str = Field(..., description="The base64 encoded image data.") + detail: Optional[str] = Field( + None, + description="What level of detail to use when processing and understanding the image (low, high, or auto to let the model decide)", + ) + + +class LettaImage(ImageSource): + type: Literal[ImageSourceType.letta] = Field(ImageSourceType.letta, description="The source type for the image.") + file_id: str = Field(..., description="The unique identifier of the image file persisted in storage.") + media_type: Optional[str] = Field(None, description="The media type for the image.") + data: Optional[str] = Field(None, description="The base64 encoded image data.") + detail: Optional[str] = Field( + None, + description="What level of detail to use when processing and understanding the image (low, high, or auto to let the model decide)", + ) + + +ImageSourceUnion = Annotated[Union[UrlImage, Base64Image, LettaImage], Field(discriminator="type")] + + +class ImageContent(MessageContent): + type: Literal[MessageContentType.image] = Field(MessageContentType.image, description="The type of the message.") + source: ImageSourceUnion = Field(..., description="The source of the image.") + + +# ------------------------------- +# User Content Types +# ------------------------------- + + LettaUserMessageContentUnion = Annotated[ - Union[TextContent], + Union[TextContent, ImageContent], Field(discriminator="type"), ] @@ -37,11 +92,13 @@ def create_letta_user_message_content_union_schema(): return { "oneOf": [ {"$ref": "#/components/schemas/TextContent"}, + {"$ref": "#/components/schemas/ImageContent"}, ], "discriminator": { "propertyName": "type", "mapping": { "text": "#/components/schemas/TextContent", + "image": "#/components/schemas/ImageContent", }, }, } @@ -150,7 +207,9 @@ class OmittedReasoningContent(MessageContent): LettaMessageContentUnion = Annotated[ - Union[TextContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent], + Union[ + TextContent, ImageContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent + ], Field(discriminator="type"), ] @@ -159,6 +218,7 @@ def create_letta_message_content_union_schema(): return { "oneOf": [ {"$ref": "#/components/schemas/TextContent"}, + {"$ref": "#/components/schemas/ImageContent"}, {"$ref": "#/components/schemas/ToolCallContent"}, {"$ref": "#/components/schemas/ToolReturnContent"}, {"$ref": "#/components/schemas/ReasoningContent"}, @@ -169,6 +229,7 @@ def create_letta_message_content_union_schema(): "propertyName": "type", "mapping": { "text": "#/components/schemas/TextContent", + "image": "#/components/schemas/ImageContent", "tool_call": "#/components/schemas/ToolCallContent", "tool_return": "#/components/schemas/ToolCallContent", "reasoning": "#/components/schemas/ReasoningContent", diff --git a/letta/schemas/message.py b/letta/schemas/message.py index eec9742a..f1b4d5ef 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -31,6 +31,7 @@ from letta.schemas.letta_message import ( UserMessage, ) from letta.schemas.letta_message_content import ( + ImageContent, LettaMessageContentUnion, OmittedReasoningContent, ReasoningContent, @@ -415,6 +416,8 @@ class Message(BaseMessage): # This is type UserMessage if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text + elif self.content and len(self.content) == 1 and isinstance(self.content[0], ImageContent): + text_content = "[Image Here]" else: raise ValueError(f"Invalid user message (no text object on message): {self.content}") @@ -658,6 +661,8 @@ class Message(BaseMessage): text_content = self.content[0].text elif self.content and len(self.content) == 1 and isinstance(self.content[0], ToolReturnContent): text_content = self.content[0].content + elif self.content and len(self.content) == 1 and isinstance(self.content[0], ImageContent): + text_content = "[Image Here]" # Otherwise, check if we have TextContent and multiple other parts elif self.content and len(self.content) > 1: text = [content for content in self.content if isinstance(content, TextContent)] @@ -755,6 +760,8 @@ class Message(BaseMessage): # Check for COT if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text + elif self.content and len(self.content) == 1 and isinstance(self.content[0], ImageContent): + text_content = "[Image Here]" else: text_content = None @@ -872,6 +879,8 @@ class Message(BaseMessage): text_content = self.content[0].text elif self.content and len(self.content) == 1 and isinstance(self.content[0], ToolReturnContent): text_content = self.content[0].content + elif self.content and len(self.content) == 1 and isinstance((self.content[0]), ImageContent): + text_content = "[Image Here]" else: text_content = None diff --git a/letta/schemas/openai/chat_completion_request.py b/letta/schemas/openai/chat_completion_request.py index 25b1e15f..f9ae397e 100644 --- a/letta/schemas/openai/chat_completion_request.py +++ b/letta/schemas/openai/chat_completion_request.py @@ -10,7 +10,7 @@ class SystemMessage(BaseModel): class UserMessage(BaseModel): - content: Union[str, List[str]] + content: Union[str, List[str], List[dict]] role: str = "user" name: Optional[str] = None diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 45a6fc6f..34a2e32f 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -1,4 +1,5 @@ import json +import uuid from typing import List, Optional, Sequence from sqlalchemy import delete, exists, func, select, text @@ -10,10 +11,12 @@ from letta.orm.message import Message as MessageModel from letta.otel.tracing import trace_method from letta.schemas.enums import MessageRole from letta.schemas.letta_message import LettaMessageUpdateUnion +from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.services.file_manager import FileManager from letta.utils import enforce_types logger = get_logger(__name__) @@ -22,6 +25,10 @@ logger = get_logger(__name__) class MessageManager: """Manager class to handle business logic related to Messages.""" + def __init__(self): + """Initialize the MessageManager.""" + self.file_manager = FileManager() + @enforce_types @trace_method def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: @@ -131,6 +138,31 @@ class MessageManager: if not pydantic_msgs: return [] + for message in pydantic_msgs: + if isinstance(message.content, list): + for content in message.content: + if content.type == MessageContentType.image and content.source.type == ImageSourceType.base64: + # TODO: actually persist image files in db + # file = await self.file_manager.create_file( # TODO: use batch create to prevent multiple db round trips + # db_session=session, + # image_create=FileMetadata( + # user_id=actor.id, # TODO: add field + # source_id= '' # TODO: make optional + # organization_id=actor.organization_id, + # file_type=content.source.media_type, + # processing_status=FileProcessingStatus.COMPLETED, + # content= '' # TODO: should content be added here or in top level text field? + # ), + # actor=actor, + # text=content.source.data, + # ) + file_id_placeholder = "file-" + str(uuid.uuid4()) + content.source = LettaImage( + file_id=file_id_placeholder, + data=content.source.data, + media_type=content.source.media_type, + detail=content.source.detail, + ) orm_messages = self._create_many_preprocess(pydantic_msgs, actor) async with db_registry.async_session() as session: created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor) diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 25b87a11..49362d83 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -521,6 +521,7 @@ class PassageManager: agent_id: str, text: str, actor: PydanticUser, + image_ids: Optional[List[str]] = None, ) -> List[PydanticPassage]: """Insert passage(s) into archival memory""" diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 12f63cf2..ab4f31d1 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -1,3 +1,4 @@ +import base64 import json import os import threading @@ -5,12 +6,23 @@ import time import uuid from typing import Any, Dict, List +import httpx import pytest import requests from dotenv import load_dotenv from letta_client import AsyncLetta, Letta, MessageCreate, Run from letta_client.core.api_error import ApiError -from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage +from letta_client.types import ( + AssistantMessage, + Base64Image, + ImageContent, + LettaUsageStatistics, + ReasoningMessage, + ToolCallMessage, + ToolReturnMessage, + UrlImage, + UserMessage, +) from letta.schemas.agent import AgentState from letta.schemas.llm_config import LLMConfig @@ -138,6 +150,30 @@ USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [ otid=USER_MESSAGE_OTID, ) ] +URL_IMAGE = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg" +USER_MESSAGE_URL_IMAGE: List[MessageCreate] = [ + MessageCreate( + role="user", + content=[ + ImageContent( + source=UrlImage(url=URL_IMAGE), + ) + ], + otid=USER_MESSAGE_OTID, + ) +] +BASE64_IMAGE = base64.standard_b64encode(httpx.get(URL_IMAGE).content).decode("utf-8") +USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [ + MessageCreate( + role="user", + content=[ + ImageContent( + source=Base64Image(data=BASE64_IMAGE, media_type="image/jpeg"), + ) + ], + otid=USER_MESSAGE_OTID, + ) +] all_configs = [ "openai-gpt-4o-mini.json", # "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop @@ -285,6 +321,42 @@ def assert_tool_call_response( assert isinstance(messages[index], LettaUsageStatistics) +def assert_image_input_response( + messages: List[Any], + streaming: bool = False, + token_streaming: bool = False, + from_db: bool = False, +) -> None: + """ + Asserts that the messages list follows the expected sequence: + ReasoningMessage -> AssistantMessage. + """ + expected_message_count = 3 if streaming or from_db else 2 + assert len(messages) == expected_message_count + + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + assert isinstance(messages[index], AssistantMessage) + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 + + if streaming: + assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 + + def accumulate_chunks(chunks: List[Any]) -> List[Any]: """ Accumulates chunks into a list of messages. @@ -421,6 +493,58 @@ def test_tool_call( assert_tool_call_response(messages_from_db, from_db=True) +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_url_image_input( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message with a synchronous client. + Verifies that the response messages follow the expected order. + """ + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=USER_MESSAGE_URL_IMAGE, + ) + assert_image_input_response(response.messages) + messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + assert_image_input_response(messages_from_db, from_db=True) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_base64_image_input( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message with a synchronous client. + Verifies that the response messages follow the expected order. + """ + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=USER_MESSAGE_BASE64_IMAGE, + ) + assert_image_input_response(response.messages) + messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + assert_image_input_response(messages_from_db, from_db=True) + + @pytest.mark.parametrize( "llm_config", TESTED_LLM_CONFIGS,