diff --git a/fern/openapi.json b/fern/openapi.json index 60bdeec6..11f1d366 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -7498,6 +7498,27 @@ "title": "Include Err" }, "description": "Whether to include error messages and error statuses. For debugging purposes only." + }, + { + "name": "message_types", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "array", + "items": { + "$ref": "#/components/schemas/MessageType" + } + }, + { + "type": "null" + } + ], + "description": "Filter to only return specified message types. If None (default), returns all message types.", + "title": "Message Types" + }, + "description": "Filter to only return specified message types. If None (default), returns all message types." } ], "responses": { diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 50181fb3..2e60def3 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1391,6 +1391,10 @@ async def list_messages( include_err: bool | None = Query( None, description="Whether to include error messages and error statuses. For debugging purposes only." ), + message_types: list[MessageType] | None = Query( + None, + description="Filter to only return specified message types. If None (default), returns all message types.", + ), headers: HeaderParams = Depends(get_headers), ): """ @@ -1410,6 +1414,7 @@ async def list_messages( assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_kwarg=assistant_message_tool_kwarg, include_err=include_err, + message_types=message_types, actor=actor, ) diff --git a/letta/server/server.py b/letta/server/server.py index 4f606035..7b79e288 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -817,6 +817,7 @@ class SyncServer(object): assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, include_err: Optional[bool] = None, + message_types: Optional[List[MessageType]] = None, ) -> Union[List[Message], List[LettaMessage]]: records = await self.message_manager.list_messages( agent_id=agent_id, @@ -844,6 +845,11 @@ class SyncServer(object): text_is_assistant_message=text_is_assistant_message, ) + # Filter by message_types if specified + if message_types: + message_types_set = set(message_types) + records = [msg for msg in records if msg.message_type in message_types_set] + if reverse: records = records[::-1] diff --git a/tests/managers/test_message_manager.py b/tests/managers/test_message_manager.py index b62b32a8..92f5d864 100644 --- a/tests/managers/test_message_manager.py +++ b/tests/managers/test_message_manager.py @@ -75,7 +75,7 @@ from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate from letta.schemas.file import FileMetadata, FileMetadata as PydanticFileMetadata from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert from letta.schemas.job import BatchJob, Job, Job as PydanticJob, JobUpdate, LettaRequestConfig -from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage +from letta.schemas.letta_message import MessageType, UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.llm_batch_job import AgentStepState, LLMBatchItem @@ -1086,3 +1086,112 @@ async def test_convert_assistant_message_with_dict_content(server: SyncServer, s assert isinstance(assistant_msg_nested.content, str) parsed_nested = json.loads(assistant_msg_nested.content) assert parsed_nested == {"status": "success", "data": {"count": 42, "items": ["a", "b"]}, "meta": None} + + +@pytest.mark.asyncio +async def test_get_agent_recall_with_message_types_filter(server: SyncServer, sarah_agent, default_user): + """ + Test that get_agent_recall_async correctly filters messages by message_types. + """ + # Create messages of different types for the agent + # User message + await server.message_manager.create_many_messages_async( + [ + PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Hello from user!")], + ), + ], + actor=default_user, + ) + + # Assistant message with tool call (will become tool_call_message) + tool_calls = [ + OpenAIToolCall( + id="call_test_1", + type="function", + function=OpenAIFunction( + name="send_message", + arguments='{"message": "Hello back!"}', + ), + ), + ] + await server.message_manager.create_many_messages_async( + [ + PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.assistant, + content=[TextContent(text="Thinking about this...")], + tool_calls=tool_calls, + ), + ], + actor=default_user, + ) + + # Tool return message + await server.message_manager.create_many_messages_async( + [ + PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.tool, + tool_call_id="call_test_1", + content=[TextContent(text='{"status": "OK"}')], + ), + ], + actor=default_user, + ) + + # Get all messages (no filter) - should have multiple types + all_messages = await server.get_agent_recall_async( + agent_id=sarah_agent.id, + actor=default_user, + return_message_object=False, + ) + assert len(all_messages) > 0 + + # Collect all unique message types in the result + all_message_types = set(msg.message_type for msg in all_messages) + # Should have at least system, user, and some assistant-related messages + assert MessageType.system_message in all_message_types or MessageType.user_message in all_message_types + + # Filter for only user messages + user_only = await server.get_agent_recall_async( + agent_id=sarah_agent.id, + actor=default_user, + return_message_object=False, + message_types=[MessageType.user_message], + ) + assert len(user_only) > 0 + for msg in user_only: + assert msg.message_type == MessageType.user_message + + # Filter for only system messages + system_only = await server.get_agent_recall_async( + agent_id=sarah_agent.id, + actor=default_user, + return_message_object=False, + message_types=[MessageType.system_message], + ) + for msg in system_only: + assert msg.message_type == MessageType.system_message + + # Filter for multiple types (user and system) + user_and_system = await server.get_agent_recall_async( + agent_id=sarah_agent.id, + actor=default_user, + return_message_object=False, + message_types=[MessageType.user_message, MessageType.system_message], + ) + for msg in user_and_system: + assert msg.message_type in (MessageType.user_message, MessageType.system_message) + + # Filter for tool_return_message + tool_return_only = await server.get_agent_recall_async( + agent_id=sarah_agent.id, + actor=default_user, + return_message_object=False, + message_types=[MessageType.tool_return_message], + ) + for msg in tool_return_only: + assert msg.message_type == MessageType.tool_return_message