feat: add multi-modal input support (#2590)
@@ -46,7 +46,7 @@ from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_
 from letta.schemas.block import BlockUpdate
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole, ProviderType
-from letta.schemas.letta_message_content import TextContent
+from letta.schemas.letta_message_content import ImageContent, TextContent
 from letta.schemas.memory import ContextWindowOverview, Memory
 from letta.schemas.message import Message, MessageCreate, ToolReturn
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
@@ -369,6 +369,16 @@ class Agent(BaseAgent):
             )
         else:
             # Fallback to existing flow
+            for message in message_sequence:
+                if isinstance(message.content, list):
+
+                    def get_fallback_text_content(content):
+                        if isinstance(content, ImageContent):
+                            return TextContent(text="[Image Here]")
+                        return content
+
+                    message.content = [get_fallback_text_content(content) for content in message.content]
+
             response = create(
                 llm_config=self.agent_state.llm_config,
                 messages=message_sequence,
@@ -820,7 +820,12 @@ class LettaAgent(BaseAgent):
             tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
         )
 
-        return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call)
+        return llm_client.build_request_data(
+            in_context_messages,
+            agent_state.llm_config,
+            allowed_tools,
+            force_tool_call,
+        )
 
     @trace_method
     async def _handle_ai_response(
@@ -12,6 +12,8 @@ from letta.schemas.agent import AgentStepState
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import ProviderType, ToolRuleType
 from letta.schemas.letta_message_content import (
+    ImageContent,
+    ImageSourceType,
     MessageContent,
     MessageContentType,
     OmittedReasoningContent,
@@ -216,12 +218,13 @@ def serialize_message_content(message_content: Optional[List[Union[MessageConten
     serialized_message_content = []
     for content in message_content:
         if isinstance(content, MessageContent):
+            if content.type == MessageContentType.image:
+                assert content.source.type == ImageSourceType.letta, f"Invalid image source type: {content.source.type}"
             serialized_message_content.append(content.model_dump(mode="json"))
         elif isinstance(content, dict):
             serialized_message_content.append(content)  # Already a dictionary, leave it as-is
         else:
             raise TypeError(f"Unexpected message content type: {type(content)}")
 
     return serialized_message_content
@@ -238,6 +241,9 @@ def deserialize_message_content(data: Optional[List[Dict]]) -> List[MessageConte
         content_type = item.get("type")
         if content_type == MessageContentType.text:
             content = TextContent(**item)
+        elif content_type == MessageContentType.image:
+            assert item["source"]["type"] == ImageSourceType.letta, f'Invalid image source type: {item["source"]["type"]}'
+            content = ImageContent(**item)
         elif content_type == MessageContentType.tool_call:
             content = ToolCallContent(**item)
         elif content_type == MessageContentType.tool_return:
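
Taken together, the two helpers above only round-trip images whose source has already been rewritten to the letta type. A minimal sketch of that round trip, with illustrative values (not part of this commit):

content = ImageContent(source=LettaImage(file_id="file-123", media_type="image/jpeg", data="aGVsbG8="))
serialized = serialize_message_content([content])
assert serialized[0]["type"] == "image"
restored = deserialize_message_content(serialized)
assert isinstance(restored[0], ImageContent)
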
@@ -1,6 +1,11 @@
+import base64
+import mimetypes
+
+import httpx
+
 from letta import system
 from letta.schemas.enums import MessageRole
-from letta.schemas.letta_message_content import TextContent
+from letta.schemas.letta_message_content import Base64Image, ImageContent, ImageSourceType, TextContent
 from letta.schemas.message import Message, MessageCreate
@@ -33,24 +38,39 @@ def _convert_message_create_to_message(
 
     # Extract message content
     if isinstance(message_create.content, str):
         assert message_create.content != "", "Message content must not be empty"
         message_content = [TextContent(text=message_create.content)]
-    elif message_create.content and len(message_create.content) > 0 and isinstance(message_create.content[0], TextContent):
-        message_content = message_create.content[0].text
+    elif isinstance(message_create.content, list) and len(message_create.content) > 0:
+        message_content = message_create.content
     else:
         raise ValueError("Message content is empty or invalid")
 
-    # Apply wrapping if needed
-    if message_create.role not in {MessageRole.user, MessageRole.system}:
-        raise ValueError(f"Invalid message role: {message_create.role}")
-    elif message_create.role == MessageRole.user and wrap_user_message:
-        message_content = system.package_user_message(user_message=message_content)
-    elif message_create.role == MessageRole.system and wrap_system_message:
-        message_content = system.package_system_message(system_message=message_content)
+    assert message_create.role in {MessageRole.user, MessageRole.system}, f"Invalid message role: {message_create.role}"
+    for content in message_content:
+        if isinstance(content, TextContent):
+            # Apply wrapping if needed
+            if message_create.role == MessageRole.user and wrap_user_message:
+                content.text = system.package_user_message(user_message=content.text)
+            elif message_create.role == MessageRole.system and wrap_system_message:
+                content.text = system.package_system_message(system_message=content.text)
+        elif isinstance(content, ImageContent):
+            if content.source.type == ImageSourceType.url:
+                # Convert URL image to Base64Image if needed
+                image_response = httpx.get(content.source.url)
+                image_response.raise_for_status()
+                image_media_type = image_response.headers.get("content-type")
+                if not image_media_type:
+                    image_media_type, _ = mimetypes.guess_type(content.source.url)
+                image_data = base64.standard_b64encode(image_response.content).decode("utf-8")
+                content.source = Base64Image(media_type=image_media_type, data=image_data)
+            if content.source.type == ImageSourceType.letta and not content.source.data:
+                # TODO: hydrate letta image with data from db
+                pass
 
     return Message(
         agent_id=agent_id,
         role=message_create.role,
-        content=[TextContent(text=message_content)] if message_content else [],
+        content=message_content,
         name=message_create.name,
         model=None,  # assigned later?
         tool_calls=None,  # irrelevant
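
Assuming the converter above, a user message that mixes text and a URL image has the URL fetched eagerly and re-encoded before storage; a sketch with a hypothetical URL:

message_create = MessageCreate(
    role=MessageRole.user,
    content=[
        TextContent(text="What is in this picture?"),
        ImageContent(source=UrlImage(url="https://example.com/ant.jpg")),  # hypothetical URL
    ],
)
# After conversion, the image part's source is a Base64Image whose media_type comes
# from the HTTP Content-Type header, with mimetypes.guess_type as the fallback.
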
@@ -29,6 +29,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG
 from letta.log import get_logger
 from letta.otel.tracing import trace_method
 from letta.schemas.enums import ProviderCategory
+from letta.schemas.letta_message_content import MessageContentType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
@@ -251,6 +252,8 @@ class AnthropicClient(LLMClientBase):
             for m in messages[1:]
         ]
 
+        data["messages"] = fill_image_content_in_messages(data["messages"], messages)
+
         # Ensure first message is user
         if data["messages"][0]["role"] != "user":
             data["messages"] = [{"role": "user", "content": DUMMY_FIRST_USER_MESSAGE}] + data["messages"]
@@ -656,3 +659,52 @@ def strip_xml_tags_streaming(string: str, tag: Optional[str]) -> str:
             result = result.replace(part, "")
 
     return result
 
 
+def fill_image_content_in_messages(anthropic_message_list: List[dict], pydantic_message_list: List[PydanticMessage]) -> List[dict]:
+    """
+    Converts image content to Anthropic format.
+    """
+
+    if len(anthropic_message_list) != len(pydantic_message_list):
+        return anthropic_message_list
+
+    new_message_list = []
+    for idx in range(len(anthropic_message_list)):
+        anthropic_message, pydantic_message = anthropic_message_list[idx], pydantic_message_list[idx]
+        if pydantic_message.role != "user":
+            new_message_list.append(anthropic_message)
+            continue
+
+        if not isinstance(pydantic_message.content, list) or (
+            len(pydantic_message.content) == 1 and pydantic_message.content[0].type == MessageContentType.text
+        ):
+            new_message_list.append(anthropic_message)
+            continue
+
+        message_content = []
+        for content in pydantic_message.content:
+            if content.type == MessageContentType.text:
+                message_content.append(
+                    {
+                        "type": "text",
+                        "text": content.text,
+                    }
+                )
+            elif content.type == MessageContentType.image:
+                message_content.append(
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": content.source.type,
+                            "media_type": content.source.media_type,
+                            "data": content.source.data,
+                        },
+                    }
+                )
+            else:
+                raise ValueError(f"Unsupported content type {content.type}")
+
+        new_message_list.append({"role": "user", "content": message_content})
+
+    return new_message_list
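
For reference, a converted user message carrying one text part and one base64 image part lands in Anthropic's content-block shape, roughly (illustrative values):

{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": "<base64>"}},
    ],
}
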
@@ -3,7 +3,7 @@ import uuid
 from typing import List, Optional
 
 from google import genai
-from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig
+from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, Part, ThinkingConfig, ToolConfig
 
 from letta.constants import NON_USER_MSG_PREFIX
 from letta.helpers.datetime_helpers import get_utc_time_int
@@ -13,6 +13,7 @@ from letta.local_llm.json_parser import clean_json_string_extra_backslash
 from letta.local_llm.utils import count_tokens
 from letta.log import get_logger
 from letta.otel.tracing import trace_method
+from letta.schemas.letta_message_content import MessageContentType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import Tool
@@ -218,7 +219,7 @@ class GoogleVertexClient(LLMClientBase):
         )
 
         request_data = {
-            "contents": contents,
+            "contents": fill_image_content_in_messages(contents, messages),
             "config": {
                 "temperature": llm_config.temperature,
                 "max_output_tokens": llm_config.max_tokens,
@@ -484,3 +485,43 @@ class GoogleVertexClient(LLMClientBase):
             "propertyOrdering": ["name", "args"],
             "required": ["name", "args"],
         }
 
 
+def fill_image_content_in_messages(google_ai_message_list: List[dict], pydantic_message_list: List[PydanticMessage]) -> List[dict]:
+    """
+    Converts image content to Google AI format.
+    """
+
+    if len(google_ai_message_list) != len(pydantic_message_list):
+        return google_ai_message_list
+
+    new_message_list = []
+    for idx in range(len(google_ai_message_list)):
+        google_ai_message, pydantic_message = google_ai_message_list[idx], pydantic_message_list[idx]
+        if pydantic_message.role != "user":
+            new_message_list.append(google_ai_message)
+            continue
+
+        if not isinstance(pydantic_message.content, list) or (
+            len(pydantic_message.content) == 1 and pydantic_message.content[0].type == MessageContentType.text
+        ):
+            new_message_list.append(google_ai_message)
+            continue
+
+        message_content = []
+        for content in pydantic_message.content:
+            if content.type == MessageContentType.text:
+                message_content.append({"text": content.text})
+            elif content.type == MessageContentType.image:
+                message_content.append(
+                    Part.from_bytes(
+                        # NOTE: Part.from_bytes expects raw bytes; a base64-encoded
+                        # string may need base64.b64decode before being passed here.
+                        mime_type=content.source.media_type,
+                        data=content.source.data,
+                    )
+                )
+            else:
+                raise ValueError(f"Unsupported content type {content.type}")
+
+        new_message_list.append({"role": "user", "parts": message_content})
+
+    return new_message_list
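
One caveat: google.genai's Part.from_bytes takes raw bytes, while source.data holds a base64 string elsewhere in this commit, so a decode step is likely needed; a sketch under that assumption:

import base64

from google.genai.types import Part

image_b64 = "aGVsbG8="  # stand-in base64 payload
part = Part.from_bytes(mime_type="image/jpeg", data=base64.b64decode(image_b64))
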
@@ -26,6 +26,7 @@ from letta.log import get_logger
 from letta.otel.tracing import trace_method
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.letta_message_content import MessageContentType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@@ -213,7 +214,7 @@ class OpenAIClient(LLMClientBase):
 
         data = ChatCompletionRequest(
             model=model,
-            messages=openai_message_list,
+            messages=fill_image_content_in_messages(openai_message_list, messages),
             tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
             tool_choice=tool_choice,
             user=str(),
@@ -402,3 +403,51 @@ class OpenAIClient(LLMClientBase):
 
         # Fallback for unexpected errors
         return super().handle_llm_error(e)
 
 
+def fill_image_content_in_messages(openai_message_list: List[dict], pydantic_message_list: List[PydanticMessage]) -> List[dict]:
+    """
+    Converts image content to OpenAI format.
+    """
+
+    if len(openai_message_list) != len(pydantic_message_list):
+        return openai_message_list
+
+    new_message_list = []
+    for idx in range(len(openai_message_list)):
+        openai_message, pydantic_message = openai_message_list[idx], pydantic_message_list[idx]
+        if pydantic_message.role != "user":
+            new_message_list.append(openai_message)
+            continue
+
+        if not isinstance(pydantic_message.content, list) or (
+            len(pydantic_message.content) == 1 and pydantic_message.content[0].type == MessageContentType.text
+        ):
+            new_message_list.append(openai_message)
+            continue
+
+        message_content = []
+        for content in pydantic_message.content:
+            if content.type == MessageContentType.text:
+                message_content.append(
+                    {
+                        "type": "text",
+                        "text": content.text,
+                    }
+                )
+            elif content.type == MessageContentType.image:
+                message_content.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:{content.source.media_type};base64,{content.source.data}",
+                            "detail": content.source.detail or "auto",
+                        },
+                    }
+                )
+            else:
+                raise ValueError(f"Unsupported content type {content.type}")
+
+        new_message_list.append({"role": "user", "content": message_content})
+
+    return new_message_list
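
The converted message then matches OpenAI's chat-completions multimodal shape, where images travel as data URIs, roughly (illustrative values):

{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<base64>", "detail": "auto"}},
    ],
}
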
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
 
 class MessageContentType(str, Enum):
     text = "text"
+    image = "image"
     tool_call = "tool_call"
     tool_return = "tool_return"
     reasoning = "reasoning"
@@ -18,7 +19,7 @@ class MessageContent(BaseModel):
 
 
 # -------------------------------
-# User Content Types
+# Text Content
 # -------------------------------
 
 
@@ -27,8 +28,62 @@ class TextContent(MessageContent):
     text: str = Field(..., description="The text content of the message.")
 
 
+# -------------------------------
+# Image Content
+# -------------------------------
+
+
+class ImageSourceType(str, Enum):
+    url = "url"
+    base64 = "base64"
+    letta = "letta"
+
+
+class ImageSource(BaseModel):
+    type: ImageSourceType = Field(..., description="The source type for the image.")
+
+
+class UrlImage(ImageSource):
+    type: Literal[ImageSourceType.url] = Field(ImageSourceType.url, description="The source type for the image.")
+    url: str = Field(..., description="The URL of the image.")
+
+
+class Base64Image(ImageSource):
+    type: Literal[ImageSourceType.base64] = Field(ImageSourceType.base64, description="The source type for the image.")
+    media_type: str = Field(..., description="The media type for the image.")
+    data: str = Field(..., description="The base64 encoded image data.")
+    detail: Optional[str] = Field(
+        None,
+        description="What level of detail to use when processing and understanding the image (low, high, or auto to let the model decide)",
+    )
+
+
+class LettaImage(ImageSource):
+    type: Literal[ImageSourceType.letta] = Field(ImageSourceType.letta, description="The source type for the image.")
+    file_id: str = Field(..., description="The unique identifier of the image file persisted in storage.")
+    media_type: Optional[str] = Field(None, description="The media type for the image.")
+    data: Optional[str] = Field(None, description="The base64 encoded image data.")
+    detail: Optional[str] = Field(
+        None,
+        description="What level of detail to use when processing and understanding the image (low, high, or auto to let the model decide)",
+    )
+
+
+ImageSourceUnion = Annotated[Union[UrlImage, Base64Image, LettaImage], Field(discriminator="type")]
+
+
+class ImageContent(MessageContent):
+    type: Literal[MessageContentType.image] = Field(MessageContentType.image, description="The type of the message.")
+    source: ImageSourceUnion = Field(..., description="The source of the image.")
 
 
 # -------------------------------
 # User Content Types
 # -------------------------------
 
 
 LettaUserMessageContentUnion = Annotated[
-    Union[TextContent],
+    Union[TextContent, ImageContent],
     Field(discriminator="type"),
 ]
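
Constructing the new content types is straightforward; a sketch with illustrative values (the discriminated unions dispatch on the type field during validation):

image_from_url = ImageContent(source=UrlImage(url="https://example.com/ant.jpg"))
image_inline = ImageContent(source=Base64Image(media_type="image/png", data="iVBORw0KGgo=", detail="low"))
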
@@ -37,11 +92,13 @@ def create_letta_user_message_content_union_schema():
     return {
         "oneOf": [
             {"$ref": "#/components/schemas/TextContent"},
+            {"$ref": "#/components/schemas/ImageContent"},
         ],
         "discriminator": {
             "propertyName": "type",
             "mapping": {
                 "text": "#/components/schemas/TextContent",
+                "image": "#/components/schemas/ImageContent",
             },
         },
     }
@@ -150,7 +207,9 @@ class OmittedReasoningContent(MessageContent):
 
 
 LettaMessageContentUnion = Annotated[
-    Union[TextContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent],
+    Union[
+        TextContent, ImageContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent
+    ],
     Field(discriminator="type"),
 ]
@@ -159,6 +218,7 @@ def create_letta_message_content_union_schema():
     return {
         "oneOf": [
             {"$ref": "#/components/schemas/TextContent"},
+            {"$ref": "#/components/schemas/ImageContent"},
             {"$ref": "#/components/schemas/ToolCallContent"},
             {"$ref": "#/components/schemas/ToolReturnContent"},
             {"$ref": "#/components/schemas/ReasoningContent"},
@@ -169,6 +229,7 @@ def create_letta_message_content_union_schema():
             "propertyName": "type",
             "mapping": {
                 "text": "#/components/schemas/TextContent",
+                "image": "#/components/schemas/ImageContent",
                 "tool_call": "#/components/schemas/ToolCallContent",
                 "tool_return": "#/components/schemas/ToolCallContent",
                 "reasoning": "#/components/schemas/ReasoningContent",
@@ -31,6 +31,7 @@ from letta.schemas.letta_message import (
     UserMessage,
 )
 from letta.schemas.letta_message_content import (
+    ImageContent,
     LettaMessageContentUnion,
     OmittedReasoningContent,
     ReasoningContent,
@@ -415,6 +416,8 @@ class Message(BaseMessage):
         # This is type UserMessage
         if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
             text_content = self.content[0].text
+        elif self.content and len(self.content) == 1 and isinstance(self.content[0], ImageContent):
+            text_content = "[Image Here]"
         else:
             raise ValueError(f"Invalid user message (no text object on message): {self.content}")
@@ -658,6 +661,8 @@ class Message(BaseMessage):
             text_content = self.content[0].text
         elif self.content and len(self.content) == 1 and isinstance(self.content[0], ToolReturnContent):
             text_content = self.content[0].content
+        elif self.content and len(self.content) == 1 and isinstance(self.content[0], ImageContent):
+            text_content = "[Image Here]"
         # Otherwise, check if we have TextContent and multiple other parts
         elif self.content and len(self.content) > 1:
             text = [content for content in self.content if isinstance(content, TextContent)]
@@ -755,6 +760,8 @@ class Message(BaseMessage):
         # Check for COT
         if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
             text_content = self.content[0].text
+        elif self.content and len(self.content) == 1 and isinstance(self.content[0], ImageContent):
+            text_content = "[Image Here]"
         else:
             text_content = None
 
@@ -872,6 +879,8 @@ class Message(BaseMessage):
             text_content = self.content[0].text
         elif self.content and len(self.content) == 1 and isinstance(self.content[0], ToolReturnContent):
             text_content = self.content[0].content
+        elif self.content and len(self.content) == 1 and isinstance(self.content[0], ImageContent):
+            text_content = "[Image Here]"
         else:
             text_content = None
 
@@ -10,7 +10,7 @@ class SystemMessage(BaseModel):
 
 
 class UserMessage(BaseModel):
-    content: Union[str, List[str]]
+    content: Union[str, List[str], List[dict]]
     role: str = "user"
     name: Optional[str] = None
 
@@ -1,4 +1,5 @@
 import json
+import uuid
 from typing import List, Optional, Sequence
 
 from sqlalchemy import delete, exists, func, select, text
@@ -10,10 +11,12 @@ from letta.orm.message import Message as MessageModel
 from letta.otel.tracing import trace_method
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import LettaMessageUpdateUnion
+from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.message import MessageUpdate
 from letta.schemas.user import User as PydanticUser
 from letta.server.db import db_registry
+from letta.services.file_manager import FileManager
 from letta.utils import enforce_types
 
 logger = get_logger(__name__)
@@ -22,6 +25,10 @@ logger = get_logger(__name__)
 class MessageManager:
     """Manager class to handle business logic related to Messages."""
 
+    def __init__(self):
+        """Initialize the MessageManager."""
+        self.file_manager = FileManager()
+
     @enforce_types
     @trace_method
     def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
@@ -131,6 +138,31 @@ class MessageManager:
         if not pydantic_msgs:
             return []
 
+        for message in pydantic_msgs:
+            if isinstance(message.content, list):
+                for content in message.content:
+                    if content.type == MessageContentType.image and content.source.type == ImageSourceType.base64:
+                        # TODO: actually persist image files in db
+                        # file = await self.file_manager.create_file(  # TODO: use batch create to prevent multiple db round trips
+                        #     db_session=session,
+                        #     image_create=FileMetadata(
+                        #         user_id=actor.id,  # TODO: add field
+                        #         source_id="",  # TODO: make optional
+                        #         organization_id=actor.organization_id,
+                        #         file_type=content.source.media_type,
+                        #         processing_status=FileProcessingStatus.COMPLETED,
+                        #         content="",  # TODO: should content be added here or in top level text field?
+                        #     ),
+                        #     actor=actor,
+                        #     text=content.source.data,
+                        # )
+                        file_id_placeholder = "file-" + str(uuid.uuid4())
+                        content.source = LettaImage(
+                            file_id=file_id_placeholder,
+                            data=content.source.data,
+                            media_type=content.source.media_type,
+                            detail=content.source.detail,
+                        )
         orm_messages = self._create_many_preprocess(pydantic_msgs, actor)
         async with db_registry.async_session() as session:
             created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor)
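
Net effect of the loop above: by the time messages reach the ORM, any inbound base64 image source has been swapped for a LettaImage carrying a generated placeholder file id; what a caller would observe, as a sketch (method name assumed from context):

created = await message_manager.create_many_messages_async(pydantic_msgs, actor=actor)  # method name assumed
source = created[0].content[0].source
assert source.type == ImageSourceType.letta and source.file_id.startswith("file-")
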
@@ -521,6 +521,7 @@ class PassageManager:
         agent_id: str,
         text: str,
         actor: PydanticUser,
+        image_ids: Optional[List[str]] = None,
     ) -> List[PydanticPassage]:
         """Insert passage(s) into archival memory"""
 
@@ -1,3 +1,4 @@
+import base64
 import json
 import os
 import threading
@@ -5,12 +6,23 @@ import time
 import uuid
 from typing import Any, Dict, List
 
+import httpx
 import pytest
 import requests
 from dotenv import load_dotenv
 from letta_client import AsyncLetta, Letta, MessageCreate, Run
 from letta_client.core.api_error import ApiError
-from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
+from letta_client.types import (
+    AssistantMessage,
+    Base64Image,
+    ImageContent,
+    LettaUsageStatistics,
+    ReasoningMessage,
+    ToolCallMessage,
+    ToolReturnMessage,
+    UrlImage,
+    UserMessage,
+)
 
 from letta.schemas.agent import AgentState
 from letta.schemas.llm_config import LLMConfig
@@ -138,6 +150,30 @@ USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
         otid=USER_MESSAGE_OTID,
     )
 ]
+URL_IMAGE = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
+USER_MESSAGE_URL_IMAGE: List[MessageCreate] = [
+    MessageCreate(
+        role="user",
+        content=[
+            ImageContent(
+                source=UrlImage(url=URL_IMAGE),
+            )
+        ],
+        otid=USER_MESSAGE_OTID,
+    )
+]
+BASE64_IMAGE = base64.standard_b64encode(httpx.get(URL_IMAGE).content).decode("utf-8")
+USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
+    MessageCreate(
+        role="user",
+        content=[
+            ImageContent(
+                source=Base64Image(data=BASE64_IMAGE, media_type="image/jpeg"),
+            )
+        ],
+        otid=USER_MESSAGE_OTID,
+    )
+]
 all_configs = [
     "openai-gpt-4o-mini.json",
     # "azure-gpt-4o-mini.json",  # TODO: Re-enable on new agent loop
@@ -285,6 +321,42 @@ def assert_tool_call_response(
     assert isinstance(messages[index], LettaUsageStatistics)
 
 
+def assert_image_input_response(
+    messages: List[Any],
+    streaming: bool = False,
+    token_streaming: bool = False,
+    from_db: bool = False,
+) -> None:
+    """
+    Asserts that the messages list follows the expected sequence:
+    ReasoningMessage -> AssistantMessage.
+    """
+    expected_message_count = 3 if streaming or from_db else 2
+    assert len(messages) == expected_message_count
+
+    index = 0
+    if from_db:
+        assert isinstance(messages[index], UserMessage)
+        assert messages[index].otid == USER_MESSAGE_OTID
+        index += 1
+
+    # Agent Step 1
+    assert isinstance(messages[index], ReasoningMessage)
+    assert messages[index].otid and messages[index].otid[-1] == "0"
+    index += 1
+
+    assert isinstance(messages[index], AssistantMessage)
+    assert messages[index].otid and messages[index].otid[-1] == "1"
+    index += 1
+
+    if streaming:
+        assert isinstance(messages[index], LettaUsageStatistics)
+        assert messages[index].prompt_tokens > 0
+        assert messages[index].completion_tokens > 0
+        assert messages[index].total_tokens > 0
+        assert messages[index].step_count > 0
 
 
 def accumulate_chunks(chunks: List[Any]) -> List[Any]:
     """
     Accumulates chunks into a list of messages.
@@ -421,6 +493,58 @@ def test_tool_call(
     assert_tool_call_response(messages_from_db, from_db=True)
 
 
+@pytest.mark.parametrize(
+    "llm_config",
+    TESTED_LLM_CONFIGS,
+    ids=[c.model for c in TESTED_LLM_CONFIGS],
+)
+def test_url_image_input(
+    disable_e2b_api_key: Any,
+    client: Letta,
+    agent_state: AgentState,
+    llm_config: LLMConfig,
+) -> None:
+    """
+    Tests sending a message with a URL image using the synchronous client.
+    Verifies that the response messages follow the expected order.
+    """
+    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
+    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
+    response = client.agents.messages.create(
+        agent_id=agent_state.id,
+        messages=USER_MESSAGE_URL_IMAGE,
+    )
+    assert_image_input_response(response.messages)
+    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
+    assert_image_input_response(messages_from_db, from_db=True)
+
+
+@pytest.mark.parametrize(
+    "llm_config",
+    TESTED_LLM_CONFIGS,
+    ids=[c.model for c in TESTED_LLM_CONFIGS],
+)
+def test_base64_image_input(
+    disable_e2b_api_key: Any,
+    client: Letta,
+    agent_state: AgentState,
+    llm_config: LLMConfig,
+) -> None:
+    """
+    Tests sending a message with a base64-encoded image using the synchronous client.
+    Verifies that the response messages follow the expected order.
+    """
+    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
+    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
+    response = client.agents.messages.create(
+        agent_id=agent_state.id,
+        messages=USER_MESSAGE_BASE64_IMAGE,
+    )
+    assert_image_input_response(response.messages)
+    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
+    assert_image_input_response(messages_from_db, from_db=True)
 
 
 @pytest.mark.parametrize(
     "llm_config",
     TESTED_LLM_CONFIGS,