diff --git a/fern/openapi.json b/fern/openapi.json index 3ae41301..b3538ac0 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -37275,7 +37275,7 @@ "anyOf": [ { "items": { - "$ref": "#/components/schemas/letta__schemas__message__ToolReturn" + "$ref": "#/components/schemas/letta__schemas__message__ToolReturn-Output" }, "type": "array" }, @@ -37391,7 +37391,7 @@ "$ref": "#/components/schemas/ApprovalReturn" }, { - "$ref": "#/components/schemas/letta__schemas__message__ToolReturn" + "$ref": "#/components/schemas/letta__schemas__message__ToolReturn-Output" } ] }, @@ -46069,7 +46069,7 @@ "anyOf": [ { "items": { - "$ref": "#/components/schemas/letta__schemas__message__ToolReturn" + "$ref": "#/components/schemas/letta__schemas__message__ToolReturn-Input" }, "type": "array" }, @@ -46131,7 +46131,7 @@ "$ref": "#/components/schemas/ApprovalReturn" }, { - "$ref": "#/components/schemas/letta__schemas__message__ToolReturn" + "$ref": "#/components/schemas/letta__schemas__message__ToolReturn-Input" } ] }, @@ -46374,8 +46374,19 @@ "default": "tool" }, "tool_return": { - "type": "string", - "title": "Tool Return" + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/LettaToolReturnContentUnion" + }, + "type": "array" + }, + { + "type": "string" + } + ], + "title": "Tool Return", + "description": "The tool return value - either a string or list of content parts (text/image)" }, "status": { "type": "string", @@ -46783,7 +46794,7 @@ "title": "UpdateStreamableHTTPMCPServer", "description": "Update schema for Streamable HTTP MCP server - all fields optional" }, - "letta__schemas__message__ToolReturn": { + "letta__schemas__message__ToolReturn-Input": { "properties": { "tool_call_id": { "anyOf": [ @@ -46836,12 +46847,117 @@ { "type": "string" }, + { + "items": { + "oneOf": [ + { + "$ref": "#/components/schemas/TextContent" + }, + { + "$ref": "#/components/schemas/ImageContent" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "image": "#/components/schemas/ImageContent", + "text": "#/components/schemas/TextContent" + } + } + }, + "type": "array" + }, { "type": "null" } ], "title": "Func Response", - "description": "The function response string" + "description": "The function response - either a string or list of content parts (text/image)" + } + }, + "type": "object", + "required": ["status"], + "title": "ToolReturn" + }, + "letta__schemas__message__ToolReturn-Output": { + "properties": { + "tool_call_id": { + "anyOf": [ + {}, + { + "type": "null" + } + ], + "title": "Tool Call Id", + "description": "The ID for the tool call" + }, + "status": { + "type": "string", + "enum": ["success", "error"], + "title": "Status", + "description": "The status of the tool call" + }, + "stdout": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stdout", + "description": "Captured stdout (e.g. prints, logs) from the tool invocation" + }, + "stderr": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stderr", + "description": "Captured stderr from the tool invocation" + }, + "func_response": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "oneOf": [ + { + "$ref": "#/components/schemas/TextContent" + }, + { + "$ref": "#/components/schemas/ImageContent" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "image": "#/components/schemas/ImageContent", + "text": "#/components/schemas/TextContent" + } + } + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Func Response", + "description": "The function response - either a string or list of content parts (text/image)" } }, "type": "object", @@ -47330,6 +47446,23 @@ } } }, + "LettaToolReturnContentUnion": { + "oneOf": [ + { + "$ref": "#/components/schemas/TextContent" + }, + { + "$ref": "#/components/schemas/ImageContent" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "text": "#/components/schemas/TextContent", + "image": "#/components/schemas/ImageContent" + } + } + }, "LettaUserMessageContentUnion": { "oneOf": [ { diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index b17f4ae9..dffe50ff 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -235,7 +235,7 @@ async def _prepare_in_context_messages_no_persist_async( "Please send a regular message to interact with the agent." ) validate_approval_tool_call_ids(current_in_context_messages[-1], input_messages[0]) - new_in_context_messages = create_approval_response_message_from_input( + new_in_context_messages = await create_approval_response_message_from_input( agent_state=agent_state, input_message=input_messages[0], run_id=run_id ) if len(input_messages) > 1: diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py index 6250bdb3..f4e142df 100644 --- a/letta/helpers/message_helper.py +++ b/letta/helpers/message_helper.py @@ -166,3 +166,61 @@ async def _convert_message_create_to_message( batch_item_id=message_create.batch_item_id, run_id=run_id, ) + + +async def _resolve_url_to_base64(url: str) -> tuple[str, str]: + """Resolve URL to base64 data and media type.""" + if url.startswith("file://"): + parsed = urlparse(url) + file_path = unquote(parsed.path) + image_bytes = await asyncio.to_thread(lambda: open(file_path, "rb").read()) + media_type, _ = mimetypes.guess_type(file_path) + media_type = media_type or "image/jpeg" + else: + image_bytes, media_type = await _fetch_image_from_url(url) + media_type = media_type or mimetypes.guess_type(url)[0] or "image/png" + + image_data = base64.standard_b64encode(image_bytes).decode("utf-8") + return image_data, media_type + + +async def resolve_tool_return_images(func_response: str | list) -> str | list: + """Resolve URL and LettaImage sources to base64 for tool returns.""" + if isinstance(func_response, str): + return func_response + + resolved = [] + for part in func_response: + if isinstance(part, ImageContent): + if part.source.type == ImageSourceType.url: + image_data, media_type = await _resolve_url_to_base64(part.source.url) + part.source = Base64Image(media_type=media_type, data=image_data) + elif part.source.type == ImageSourceType.letta and not part.source.data: + pass + resolved.append(part) + elif isinstance(part, TextContent): + resolved.append(part) + elif isinstance(part, dict): + if part.get("type") == "image" and part.get("source", {}).get("type") == "url": + url = part["source"].get("url") + if url: + image_data, media_type = await _resolve_url_to_base64(url) + resolved.append( + ImageContent( + source=Base64Image( + media_type=media_type, + data=image_data, + detail=part.get("source", {}).get("detail"), + ) + ) + ) + else: + resolved.append(part) + elif part.get("type") == "text": + resolved.append(TextContent(text=part.get("text", ""))) + else: + resolved.append(part) + else: + resolved.append(part) + + return resolved diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index a403460c..712071ae 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -7,8 +7,10 @@ from pydantic import BaseModel, Field, field_serializer, field_validator from letta.schemas.letta_message_content import ( LettaAssistantMessageContentUnion, + LettaToolReturnContentUnion, LettaUserMessageContentUnion, get_letta_assistant_message_content_union_str_json_schema, + get_letta_tool_return_content_union_str_json_schema, get_letta_user_message_content_union_str_json_schema, ) @@ -35,7 +37,11 @@ class ApprovalReturn(MessageReturn): class ToolReturn(MessageReturn): type: Literal[MessageReturnType.tool] = Field(default=MessageReturnType.tool, description="The message type to be created.") - tool_return: str + tool_return: Union[str, List[LettaToolReturnContentUnion]] = Field( + ..., + description="The tool return value - either a string or list of content parts (text/image)", + json_schema_extra=get_letta_tool_return_content_union_str_json_schema(), + ) status: Literal["success", "error"] tool_call_id: str stdout: Optional[List[str]] = None diff --git a/letta/schemas/letta_message_content.py b/letta/schemas/letta_message_content.py index 24265777..7c62ebd3 100644 --- a/letta/schemas/letta_message_content.py +++ b/letta/schemas/letta_message_content.py @@ -138,6 +138,48 @@ def get_letta_user_message_content_union_str_json_schema(): } +# ------------------------------- +# Tool Return Content Types +# ------------------------------- + + +LettaToolReturnContentUnion = Annotated[ + Union[TextContent, ImageContent], + Field(discriminator="type"), +] + + +def create_letta_tool_return_content_union_schema(): + return { + "oneOf": [ + {"$ref": "#/components/schemas/TextContent"}, + {"$ref": "#/components/schemas/ImageContent"}, + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "text": "#/components/schemas/TextContent", + "image": "#/components/schemas/ImageContent", + }, + }, + } + + +def get_letta_tool_return_content_union_str_json_schema(): + """Schema that accepts either string or list of content parts for tool returns.""" + return { + "anyOf": [ + { + "type": "array", + "items": { + "$ref": "#/components/schemas/LettaToolReturnContentUnion", + }, + }, + {"type": "string"}, + ], + } + + # ------------------------------- # Assistant Content Types # ------------------------------- diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 0af6d2f0..102b8824 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -50,6 +50,7 @@ from letta.schemas.letta_message_content import ( ImageContent, ImageSourceType, LettaMessageContentUnion, + LettaToolReturnContentUnion, OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, @@ -71,6 +72,34 @@ def truncate_tool_return(content: Optional[str], limit: Optional[int]) -> Option return content[:limit] + f"... [truncated {len(content) - limit} chars]" +def _get_text_from_part(part: Union[TextContent, ImageContent, dict]) -> Optional[str]: + """Extract text from a content part, returning None for images.""" + if isinstance(part, TextContent): + return part.text + elif isinstance(part, dict) and part.get("type") == "text": + return part.get("text", "") + return None + + +def tool_return_to_text(func_response: Optional[Union[str, List]]) -> Optional[str]: + """Convert tool return content to text, replacing images with placeholders.""" + if func_response is None: + return None + if isinstance(func_response, str): + return func_response + + text_parts = [text for part in func_response if (text := _get_text_from_part(part))] + image_count = sum( + 1 for part in func_response if isinstance(part, ImageContent) or (isinstance(part, dict) and part.get("type") == "image") + ) + + result = "\n".join(text_parts) + if image_count > 0: + placeholder = "[Image omitted]" if image_count == 1 else f"[{image_count} images omitted]" + result = (result + " " + placeholder) if result else placeholder + return result if result else None + + def add_inner_thoughts_to_tool_call( tool_call: OpenAIToolCall, inner_thoughts: str, @@ -786,8 +815,14 @@ class Message(BaseMessage): for tool_return in self.tool_returns: parsed_data = self._parse_tool_response(tool_return.func_response) + # Preserve multi-modal content (ToolReturn supports Union[str, List]) + if isinstance(tool_return.func_response, list): + tool_return_value = tool_return.func_response + else: + tool_return_value = parsed_data["message"] + tool_return_obj = LettaToolReturn( - tool_return=parsed_data["message"], + tool_return=tool_return_value, status=parsed_data["status"], tool_call_id=tool_return.tool_call_id, stdout=tool_return.stdout, @@ -801,11 +836,18 @@ class Message(BaseMessage): first_tool_return = all_tool_returns[0] + # Convert deprecated string-only field to text (preserve images in tool_returns list) + deprecated_tool_return_text = ( + tool_return_to_text(first_tool_return.tool_return) + if isinstance(first_tool_return.tool_return, list) + else first_tool_return.tool_return + ) + return ToolReturnMessage( id=self.id, date=self.created_at, # deprecated top-level fields populated from first tool return - tool_return=first_tool_return.tool_return, + tool_return=deprecated_tool_return_text, status=first_tool_return.status, tool_call_id=first_tool_return.tool_call_id, stdout=first_tool_return.stdout, @@ -840,11 +882,11 @@ class Message(BaseMessage): """Check if message has exactly one text content item.""" return self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent) - def _parse_tool_response(self, response_text: str) -> dict: + def _parse_tool_response(self, response_text: Union[str, List]) -> dict: """Parse tool response JSON and extract message and status. Args: - response_text: Raw JSON response text + response_text: Raw JSON response text OR list of content parts (for multi-modal) Returns: Dictionary with 'message' and 'status' keys @@ -852,6 +894,14 @@ class Message(BaseMessage): Raises: ValueError: If JSON parsing fails """ + # Handle multi-modal content (list with text/images) + if isinstance(response_text, list): + text_representation = tool_return_to_text(response_text) or "[Multi-modal content]" + return { + "message": text_representation, + "status": "success", + } + try: function_return = parse_json(response_text) return { @@ -1301,7 +1351,9 @@ class Message(BaseMessage): tool_return = self.tool_returns[0] if not tool_return.tool_call_id: raise TypeError("OpenAI API requires tool_call_id to be set.") - func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars) + # Convert to text first (replaces images with placeholders), then truncate + func_response_text = tool_return_to_text(tool_return.func_response) + func_response = truncate_tool_return(func_response_text, tool_return_truncation_chars) openai_message = { "content": func_response, "role": self.role, @@ -1356,8 +1408,9 @@ class Message(BaseMessage): for tr in m.tool_returns: if not tr.tool_call_id: raise TypeError("ToolReturn came back without a tool_call_id.") - # Ensure explicit tool_returns are truncated for Chat Completions - func_response = truncate_tool_return(tr.func_response, tool_return_truncation_chars) + # Convert multi-modal to text (images → placeholders), then truncate + func_response_text = tool_return_to_text(tr.func_response) + func_response = truncate_tool_return(func_response_text, tool_return_truncation_chars) result.append( { "content": func_response, @@ -1456,17 +1509,17 @@ class Message(BaseMessage): ) elif self.role == "tool": - # Handle tool returns - similar pattern to Anthropic + # Handle tool returns - supports images via content arrays if self.tool_returns: for tool_return in self.tool_returns: if not tool_return.tool_call_id: raise TypeError("OpenAI Responses API requires tool_call_id to be set.") - func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars) + output = self._tool_return_to_responses_output(tool_return.func_response, tool_return_truncation_chars) message_dicts.append( { "type": "function_call_output", "call_id": tool_return.tool_call_id[:max_tool_id_length] if max_tool_id_length else tool_return.tool_call_id, - "output": func_response, + "output": output, } ) else: @@ -1534,6 +1587,50 @@ class Message(BaseMessage): return None + @staticmethod + def _image_dict_to_data_url(part: dict) -> Optional[str]: + """Convert image dict to data URL.""" + source = part.get("source", {}) + if source.get("type") == "base64" and source.get("data"): + media_type = source.get("media_type", "image/png") + return f"data:{media_type};base64,{source['data']}" + elif source.get("type") == "url": + return source.get("url") + return None + + @staticmethod + def _tool_return_to_responses_output( + func_response: Optional[Union[str, List]], + tool_return_truncation_chars: Optional[int] = None, + ) -> Union[str, List[dict]]: + """Convert tool return to OpenAI Responses API format.""" + if func_response is None: + return "" + if isinstance(func_response, str): + return truncate_tool_return(func_response, tool_return_truncation_chars) or "" + + output_parts: List[dict] = [] + for part in func_response: + if isinstance(part, TextContent): + text = truncate_tool_return(part.text, tool_return_truncation_chars) or "" + output_parts.append({"type": "input_text", "text": text}) + elif isinstance(part, ImageContent): + image_url = Message._image_source_to_data_url(part) + if image_url: + detail = getattr(part.source, "detail", None) or "auto" + output_parts.append({"type": "input_image", "image_url": image_url, "detail": detail}) + elif isinstance(part, dict): + if part.get("type") == "text": + text = truncate_tool_return(part.get("text", ""), tool_return_truncation_chars) or "" + output_parts.append({"type": "input_text", "text": text}) + elif part.get("type") == "image": + image_url = Message._image_dict_to_data_url(part) + if image_url: + detail = part.get("source", {}).get("detail", "auto") + output_parts.append({"type": "input_image", "image_url": image_url, "detail": detail}) + + return output_parts if output_parts else "" + @staticmethod def to_openai_responses_dicts_from_list( messages: List[Message], @@ -1550,6 +1647,68 @@ class Message(BaseMessage): ) return result + @staticmethod + def _get_base64_image_data(part: Union[ImageContent, dict]) -> Optional[tuple[str, str]]: + """Extract base64 data and media type from ImageContent or dict.""" + if isinstance(part, ImageContent): + source = part.source + if source.type == ImageSourceType.base64: + return source.data, source.media_type + elif source.type == ImageSourceType.letta and getattr(source, "data", None): + return source.data, getattr(source, "media_type", None) or "image/png" + elif isinstance(part, dict) and part.get("type") == "image": + source = part.get("source", {}) + if source.get("type") == "base64" and source.get("data"): + return source["data"], source.get("media_type", "image/png") + return None + + @staticmethod + def _tool_return_to_google_parts( + func_response: Optional[Union[str, List]], + tool_return_truncation_chars: Optional[int] = None, + ) -> tuple[str, List[dict]]: + """Extract text and image parts for Google API format.""" + if isinstance(func_response, str): + return truncate_tool_return(func_response, tool_return_truncation_chars) or "", [] + + text_parts = [] + image_parts = [] + for part in func_response: + if text := _get_text_from_part(part): + text_parts.append(text) + elif image_data := Message._get_base64_image_data(part): + data, media_type = image_data + image_parts.append({"inlineData": {"data": data, "mimeType": media_type}}) + + text = truncate_tool_return("\n".join(text_parts), tool_return_truncation_chars) or "" + if image_parts: + suffix = f"[{len(image_parts)} image(s) attached]" + text = f"{text}\n{suffix}" if text else suffix + + return text, image_parts + + @staticmethod + def _tool_return_to_anthropic_content( + func_response: Optional[Union[str, List]], + tool_return_truncation_chars: Optional[int] = None, + ) -> Union[str, List[dict]]: + """Convert tool return to Anthropic tool_result content format.""" + if func_response is None: + return "" + if isinstance(func_response, str): + return truncate_tool_return(func_response, tool_return_truncation_chars) or "" + + content: List[dict] = [] + for part in func_response: + if text := _get_text_from_part(part): + text = truncate_tool_return(text, tool_return_truncation_chars) or "" + content.append({"type": "text", "text": text}) + elif image_data := Message._get_base64_image_data(part): + data, media_type = image_data + content.append({"type": "image", "source": {"type": "base64", "data": data, "media_type": media_type}}) + + return content if content else "" + def to_anthropic_dict( self, current_model: str, @@ -1759,12 +1918,13 @@ class Message(BaseMessage): f"Message ID: {self.id}, Tool: {self.name or 'unknown'}, " f"Tool return index: {idx}/{len(self.tool_returns)}" ) - func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars) + # Convert to Anthropic format (supports images) + tool_result_content = self._tool_return_to_anthropic_content(tool_return.func_response, tool_return_truncation_chars) content.append( { "type": "tool_result", "tool_use_id": resolved_tool_call_id, - "content": func_response, + "content": tool_result_content, } ) if content: @@ -2003,7 +2163,7 @@ class Message(BaseMessage): elif self.role == "tool": # NOTE: Significantly different tool calling format, more similar to function calling format - # Handle tool returns - similar pattern to Anthropic + # Handle tool returns - Google supports images as sibling inlineData parts if self.tool_returns: parts = [] for tool_return in self.tool_returns: @@ -2013,26 +2173,24 @@ class Message(BaseMessage): # Use the function name if available, otherwise use tool_call_id function_name = self.name if self.name else tool_return.tool_call_id - # Truncate the tool return if needed - func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars) + text_content, image_parts = Message._tool_return_to_google_parts( + tool_return.func_response, tool_return_truncation_chars + ) - # NOTE: Google AI API wants the function response as JSON only, no string try: - function_response = parse_json(func_response) + function_response = parse_json(text_content) except: - function_response = {"function_response": func_response} + function_response = {"function_response": text_content} parts.append( { "functionResponse": { "name": function_name, - "response": { - "name": function_name, # NOTE: name twice... why? - "content": function_response, - }, + "response": {"name": function_name, "content": function_response}, } } ) + parts.extend(image_parts) google_ai_message = { "role": "function", @@ -2325,7 +2483,9 @@ class ToolReturn(BaseModel): status: Literal["success", "error"] = Field(..., description="The status of the tool call") stdout: Optional[List[str]] = Field(default=None, description="Captured stdout (e.g. prints, logs) from the tool invocation") stderr: Optional[List[str]] = Field(default=None, description="Captured stderr from the tool invocation") - func_response: Optional[str] = Field(None, description="The function response string") + func_response: Optional[Union[str, List[LettaToolReturnContentUnion]]] = Field( + None, description="The function response - either a string or list of content parts (text/image)" + ) class MessageSearchRequest(BaseModel): diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index d56fb4f2..e4ac3333 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -64,6 +64,7 @@ from letta.schemas.letta_message import create_letta_error_message_schema, creat from letta.schemas.letta_message_content import ( create_letta_assistant_message_content_union_schema, create_letta_message_content_union_schema, + create_letta_tool_return_content_union_schema, create_letta_user_message_content_union_schema, ) from letta.server.constants import REST_DEFAULT_PORT @@ -105,6 +106,7 @@ def generate_openapi_schema(app: FastAPI): letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema() letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema() letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema() + letta_docs["components"]["schemas"]["LettaToolReturnContentUnion"] = create_letta_tool_return_content_union_schema() letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema() letta_docs["components"]["schemas"]["LettaErrorMessage"] = create_letta_error_message_schema() diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 25186f8e..66e15572 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -20,7 +20,7 @@ from letta.constants import ( ) from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms -from letta.helpers.message_helper import convert_message_creates_to_messages +from letta.helpers.message_helper import convert_message_creates_to_messages, resolve_tool_return_images from letta.log import get_logger from letta.otel.context import get_ctx_attributes from letta.otel.metric_registry import MetricRegistry @@ -171,18 +171,26 @@ async def create_input_messages( return messages -def create_approval_response_message_from_input( +async def create_approval_response_message_from_input( agent_state: AgentState, input_message: ApprovalCreate, run_id: Optional[str] = None ) -> List[Message]: - def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn): + async def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn): if isinstance(maybe_tool_return, LettaToolReturn): - packaged_function_response = package_function_response( - maybe_tool_return.status == "success", maybe_tool_return.tool_return, agent_state.timezone - ) + tool_return_content = maybe_tool_return.tool_return + + # Handle tool_return content - can be string or list of content parts (text/image) + if isinstance(tool_return_content, str): + # String content - wrap with package_function_response as before + func_response = package_function_response(maybe_tool_return.status == "success", tool_return_content, agent_state.timezone) + else: + # List of content parts (text/image) - resolve URL images to base64 first + resolved_content = await resolve_tool_return_images(tool_return_content) + func_response = resolved_content + return ToolReturn( tool_call_id=maybe_tool_return.tool_call_id, status=maybe_tool_return.status, - func_response=packaged_function_response, + func_response=func_response, stdout=maybe_tool_return.stdout, stderr=maybe_tool_return.stderr, ) @@ -196,6 +204,11 @@ def create_approval_response_message_from_input( getattr(input_message, "approval_request_id", None), ) + # Process all tool returns concurrently (for async image resolution) + import asyncio + + converted_approvals = await asyncio.gather(*[maybe_convert_tool_return_message(approval) for approval in approvals_list]) + return [ Message( role=MessageRole.approval, @@ -204,7 +217,7 @@ def create_approval_response_message_from_input( approval_request_id=input_message.approval_request_id, approve=input_message.approve, denial_reason=input_message.reason, - approvals=[maybe_convert_tool_return_message(approval) for approval in approvals_list], + approvals=list(converted_approvals), run_id=run_id, group_id=input_message.group_id if input_message.group_id diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index be550734..4aedc99e 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -719,7 +719,7 @@ class RunManager: ) # Use the standard function to create properly formatted approval response messages - approval_response_messages = create_approval_response_message_from_input( + approval_response_messages = await create_approval_response_message_from_input( agent_state=agent_state, input_message=approval_input, run_id=run_id, diff --git a/tests/data/secret.png b/tests/data/secret.png new file mode 100644 index 00000000..c75d0884 Binary files /dev/null and b/tests/data/secret.png differ diff --git a/tests/integration_test_multi_modal_tool_returns.py b/tests/integration_test_multi_modal_tool_returns.py new file mode 100644 index 00000000..831913e6 --- /dev/null +++ b/tests/integration_test_multi_modal_tool_returns.py @@ -0,0 +1,408 @@ +""" +Integration tests for multi-modal tool returns (images in tool responses). + +These tests verify that: +1. Models supporting images in tool returns can see and describe image content +2. Models NOT supporting images (e.g., Chat Completions API) receive placeholder text +3. The image data is properly passed through the approval flow + +The test uses a secret.png image containing hidden text that the model must identify. +""" + +import base64 +import os +import uuid + +import pytest +from letta_client import Letta +from letta_client.types.agents import ApprovalRequestMessage, AssistantMessage, ToolCallMessage + +# ------------------------------ +# Constants +# ------------------------------ + +# The secret text embedded in the test image +# This is the actual text visible in secret.png +SECRET_TEXT_IN_IMAGE = "FIREBRAWL" + +# Models that support images in tool returns (Responses API, Anthropic, or Google AI) +MODELS_WITH_IMAGE_SUPPORT = [ + "anthropic/claude-sonnet-4-5-20250929", + "openai/gpt-5", # Uses Responses API + "google_ai/gemini-2.5-flash", # Google AI with vision support +] + +# Models that do NOT support images in tool returns (Chat Completions only) +MODELS_WITHOUT_IMAGE_SUPPORT = [ + "openai/gpt-4o-mini", # Uses Chat Completions API, not Responses +] + + +def _load_secret_image() -> str: + """Loads the secret test image and returns it as base64.""" + image_path = os.path.join(os.path.dirname(__file__), "data/secret.png") + with open(image_path, "rb") as f: + return base64.standard_b64encode(f.read()).decode("utf-8") + + +SECRET_IMAGE_BASE64 = _load_secret_image() + + +def get_image_tool_schema(): + """Returns a client-side tool schema that returns an image.""" + return { + "name": "get_secret_image", + "description": "Retrieves a secret image with hidden text. Call this function to get the image.", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + } + + +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture +def client(server_url: str) -> Letta: + """Create a Letta client.""" + return Letta(base_url=server_url) + + +# ------------------------------ +# Test Cases +# ------------------------------ + + +class TestMultiModalToolReturns: + """Test multi-modal (image) content in tool returns.""" + + @pytest.mark.parametrize("model", MODELS_WITH_IMAGE_SUPPORT) + def test_model_can_see_image_in_tool_return(self, client: Letta, model: str) -> None: + """ + Test that models supporting images can see and describe image content + returned from a tool. + + Flow: + 1. User asks agent to get the secret image and tell them what's in it + 2. Agent calls client-side tool, execution pauses + 3. Client provides tool return with image content + 4. Agent processes the image and describes what it sees + 5. Verify the agent mentions the secret text from the image + """ + # Create agent for this test + agent = client.agents.create( + name=f"multimodal_test_{uuid.uuid4().hex[:8]}", + model=model, + embedding="openai/text-embedding-3-small", + include_base_tools=False, + tool_ids=[], + include_base_tool_rules=False, + tool_rules=[], + ) + + try: + tool_schema = get_image_tool_schema() + print(f"\n=== Testing image support with model: {model} ===") + + # Step 1: User asks for the secret image + print("\nStep 1: Asking agent to call get_secret_image tool...") + response1 = client.agents.messages.create( + agent_id=agent.id, + messages=[ + { + "role": "user", + "content": "Call the get_secret_image function now.", + } + ], + client_tools=[tool_schema], + ) + + # Validate Step 1: Should pause with approval request + assert response1.stop_reason.stop_reason == "requires_approval", f"Expected requires_approval, got {response1.stop_reason}" + + # Find the approval request with tool call + approval_msg = None + for msg in response1.messages: + if isinstance(msg, ApprovalRequestMessage): + approval_msg = msg + break + + assert approval_msg is not None, f"Expected an ApprovalRequestMessage but got {[type(m).__name__ for m in response1.messages]}" + assert approval_msg.tool_call.name == "get_secret_image" + + print(f"Tool call ID: {approval_msg.tool_call.tool_call_id}") + + # Step 2: Provide tool return with image content + print("\nStep 2: Providing tool return with image...") + + # Build image content as list of content parts + image_content = [ + {"type": "text", "text": "Here is the secret image:"}, + { + "type": "image", + "source": { + "type": "base64", + "data": SECRET_IMAGE_BASE64, + "media_type": "image/png", + }, + }, + ] + + response2 = client.agents.messages.create( + agent_id=agent.id, + messages=[ + { + "type": "approval", + "approvals": [ + { + "type": "tool", + "tool_call_id": approval_msg.tool_call.tool_call_id, + "tool_return": image_content, + "status": "success", + }, + ], + }, + ], + ) + + # Validate Step 2: Agent should process the image and respond + print(f"Stop reason: {response2.stop_reason}") + print(f"Messages: {len(response2.messages)}") + + # Find the assistant message with the response + assistant_response = None + for msg in response2.messages: + if isinstance(msg, AssistantMessage): + assistant_response = msg.content + print(f"Assistant response: {assistant_response[:200]}...") + break + + assert assistant_response is not None, "Expected an AssistantMessage with the image description" + + # Verify the model saw the secret text in the image + # The model should mention the secret code if it can see the image + assert SECRET_TEXT_IN_IMAGE in assistant_response.upper() or SECRET_TEXT_IN_IMAGE.lower() in assistant_response.lower(), ( + f"Model should have seen the secret text '{SECRET_TEXT_IN_IMAGE}' in the image, but response was: {assistant_response}" + ) + + print("\nSUCCESS: Model correctly identified secret text in image!") + + finally: + # Cleanup + client.agents.delete(agent_id=agent.id) + + @pytest.mark.parametrize("model", MODELS_WITHOUT_IMAGE_SUPPORT) + def test_model_without_image_support_gets_placeholder(self, client: Letta, model: str) -> None: + """ + Test that models NOT supporting images receive placeholder text + and cannot see the actual image content. + + This verifies that Chat Completions API models (which don't support + images in tool results) get a graceful fallback. + + Flow: + 1. User asks agent to get the secret image + 2. Agent calls client-side tool, execution pauses + 3. Client provides tool return with image content + 4. Agent processes but CANNOT see the image (only placeholder text) + 5. Verify the agent does NOT mention the secret text + """ + # Create agent for this test + agent = client.agents.create( + name=f"no_image_test_{uuid.uuid4().hex[:8]}", + model=model, + embedding="openai/text-embedding-3-small", + include_base_tools=False, + tool_ids=[], + include_base_tool_rules=False, + tool_rules=[], + ) + + try: + tool_schema = get_image_tool_schema() + print(f"\n=== Testing placeholder for model without image support: {model} ===") + + # Step 1: User asks for the secret image + print("\nStep 1: Asking agent to call get_secret_image tool...") + response1 = client.agents.messages.create( + agent_id=agent.id, + messages=[ + { + "role": "user", + "content": "Call the get_secret_image function now.", + } + ], + client_tools=[tool_schema], + ) + + # Validate Step 1: Should pause with approval request + assert response1.stop_reason.stop_reason == "requires_approval", f"Expected requires_approval, got {response1.stop_reason}" + + # Find the approval request with tool call + approval_msg = None + for msg in response1.messages: + if isinstance(msg, ApprovalRequestMessage): + approval_msg = msg + break + + assert approval_msg is not None, f"Expected an ApprovalRequestMessage but got {[type(m).__name__ for m in response1.messages]}" + + # Step 2: Provide tool return with image content + print("\nStep 2: Providing tool return with image...") + + image_content = [ + {"type": "text", "text": "Here is the secret image:"}, + { + "type": "image", + "source": { + "type": "base64", + "data": SECRET_IMAGE_BASE64, + "media_type": "image/png", + }, + }, + ] + + response2 = client.agents.messages.create( + agent_id=agent.id, + messages=[ + { + "type": "approval", + "approvals": [ + { + "type": "tool", + "tool_call_id": approval_msg.tool_call.tool_call_id, + "tool_return": image_content, + "status": "success", + }, + ], + }, + ], + ) + + # Find the assistant message + assistant_response = None + for msg in response2.messages: + if isinstance(msg, AssistantMessage): + assistant_response = msg.content + print(f"Assistant response: {assistant_response[:200]}...") + break + + assert assistant_response is not None, "Expected an AssistantMessage" + + # Verify the model did NOT see the secret text (it got placeholder instead) + assert ( + SECRET_TEXT_IN_IMAGE not in assistant_response.upper() and SECRET_TEXT_IN_IMAGE.lower() not in assistant_response.lower() + ), ( + f"Model should NOT have seen the secret text '{SECRET_TEXT_IN_IMAGE}' (it doesn't support images), " + f"but response was: {assistant_response}" + ) + + # The model should mention something about image being omitted/not visible + response_lower = assistant_response.lower() + mentions_image_issue = any( + phrase in response_lower + for phrase in ["image", "omitted", "cannot see", "can't see", "unable to", "not able to", "no image"] + ) + + print("\nSUCCESS: Model correctly did not see the secret (image support not available)") + if mentions_image_issue: + print("Model acknowledged it cannot see the image content") + + finally: + # Cleanup + client.agents.delete(agent_id=agent.id) + + +class TestMultiModalToolReturnsSerialization: + """Test that multi-modal tool returns serialize/deserialize correctly.""" + + @pytest.mark.parametrize("model", MODELS_WITH_IMAGE_SUPPORT[:1]) # Just test one model + def test_tool_return_with_image_persists_in_db(self, client: Letta, model: str) -> None: + """ + Test that tool returns with images are correctly persisted and + can be retrieved from the database. + """ + agent = client.agents.create( + name=f"persist_test_{uuid.uuid4().hex[:8]}", + model=model, + embedding="openai/text-embedding-3-small", + include_base_tools=False, + tool_ids=[], + include_base_tool_rules=False, + tool_rules=[], + ) + + try: + tool_schema = get_image_tool_schema() + + # Trigger tool call + response1 = client.agents.messages.create( + agent_id=agent.id, + messages=[{"role": "user", "content": "Call the get_secret_image tool."}], + client_tools=[tool_schema], + ) + + assert response1.stop_reason.stop_reason == "requires_approval" + + approval_msg = None + for msg in response1.messages: + if isinstance(msg, ApprovalRequestMessage): + approval_msg = msg + break + + assert approval_msg is not None + + # Provide image tool return + image_content = [ + {"type": "text", "text": "Image result"}, + { + "type": "image", + "source": { + "type": "base64", + "data": SECRET_IMAGE_BASE64, + "media_type": "image/png", + }, + }, + ] + + response2 = client.agents.messages.create( + agent_id=agent.id, + messages=[ + { + "type": "approval", + "approvals": [ + { + "type": "tool", + "tool_call_id": approval_msg.tool_call.tool_call_id, + "tool_return": image_content, + "status": "success", + }, + ], + }, + ], + ) + + # Verify we got a response + assert response2.stop_reason is not None + + # Retrieve messages from DB and verify they persisted + messages_from_db = client.agents.messages.list(agent_id=agent.id) + + # Look for the tool return message in the persisted messages + found_tool_return = False + for msg in messages_from_db.items: + # Check if this is a tool return message that might contain our image + if hasattr(msg, "tool_returns") and msg.tool_returns: + found_tool_return = True + break + + # The tool return should have been saved + print(f"Found {len(messages_from_db.items)} messages in DB") + print(f"Tool return message found: {found_tool_return}") + + finally: + client.agents.delete(agent_id=agent.id)