From 3e6729214b9d502f718eb179fbb0e51ef658a197 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 6 Jun 2025 15:48:27 -0700 Subject: [PATCH] feat: add `include_return_message_types` to `LettaRequest` to filter down requests (#2666) --- letta/agents/helpers.py | 15 +++- letta/agents/letta_agent.py | 26 +++++-- letta/agents/voice_sleeptime_agent.py | 15 +++- letta/groups/sleeptime_multi_agent_v2.py | 14 +++- letta/schemas/letta_request.py | 6 ++ letta/server/rest_api/routers/v1/agents.py | 10 ++- letta/server/server.py | 8 ++- tests/test_sdk_client.py | 83 ++++++++++++++++++++++ 8 files changed, 164 insertions(+), 13 deletions(-) diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 0d653813..5e96996a 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -1,8 +1,9 @@ import uuid import xml.etree.ElementTree as ET -from typing import List, Tuple +from typing import List, Optional, Tuple from letta.schemas.agent import AgentState +from letta.schemas.letta_message import MessageType from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageCreate from letta.schemas.usage import LettaUsageStatistics @@ -12,16 +13,26 @@ from letta.services.message_manager import MessageManager def _create_letta_response( - new_in_context_messages: list[Message], use_assistant_message: bool, usage: LettaUsageStatistics + new_in_context_messages: list[Message], + use_assistant_message: bool, + usage: LettaUsageStatistics, + include_return_message_types: Optional[List[MessageType]] = None, ) -> LettaResponse: """ Converts the newly created/persisted messages into a LettaResponse. """ # NOTE: hacky solution to avoid returning heartbeat messages and the original user message filter_user_messages = [m for m in new_in_context_messages if m.role != "user"] + + # Convert to Letta messages first response_messages = Message.to_letta_messages_from_list( messages=filter_user_messages, use_assistant_message=use_assistant_message, reverse=False ) + + # Apply message type filtering if specified + if include_return_message_types is not None: + response_messages = [msg for msg in response_messages if msg.message_type in include_return_message_types] + return LettaResponse(messages=response_messages, usage=usage) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 93a575c1..79e64407 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -30,6 +30,7 @@ from letta.otel.metric_registry import MetricRegistry from letta.otel.tracing import log_event, trace_method, tracer from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole, MessageStreamStatus +from letta.schemas.letta_message import MessageType from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig @@ -121,6 +122,7 @@ class LettaAgent(BaseAgent): max_steps: int = 10, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, + include_return_message_types: Optional[List[MessageType]] = None, ) -> LettaResponse: agent_state = await self.agent_manager.get_agent_by_id_async( agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor @@ -132,7 +134,10 @@ class LettaAgent(BaseAgent): request_start_timestamp_ns=request_start_timestamp_ns, ) return _create_letta_response( - 
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage + new_in_context_messages=new_in_context_messages, + use_assistant_message=use_assistant_message, + usage=usage, + include_return_message_types=include_return_message_types, ) @trace_method @@ -142,6 +147,7 @@ class LettaAgent(BaseAgent): max_steps: int = 10, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, + include_return_message_types: Optional[List[MessageType]] = None, ): agent_state = await self.agent_manager.get_agent_by_id_async( agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor @@ -250,8 +256,12 @@ class LettaAgent(BaseAgent): letta_messages = Message.to_letta_messages_from_list( filter_user_messages, use_assistant_message=use_assistant_message, reverse=False ) + for message in letta_messages: - yield f"data: {message.model_dump_json()}\n\n" + if not include_return_message_types: + yield f"data: {message.model_dump_json()}\n\n" + elif include_return_message_types and message.message_type in include_return_message_types: + yield f"data: {message.model_dump_json()}\n\n" if not should_continue: break @@ -409,6 +419,7 @@ class LettaAgent(BaseAgent): max_steps: int = 10, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, + include_return_message_types: Optional[List[MessageType]] = None, ) -> AsyncGenerator[str, None]: """ Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens. @@ -486,7 +497,12 @@ class LettaAgent(BaseAgent): request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)}) first_chunk = False - yield f"data: {chunk.model_dump_json()}\n\n" + if include_return_message_types is None: + # return all data + yield f"data: {chunk.model_dump_json()}\n\n" + elif include_return_message_types and chunk.message_type in include_return_message_types: + # filter down returned data + yield f"data: {chunk.model_dump_json()}\n\n" # update usage usage.step_count += 1 @@ -563,7 +579,9 @@ class LettaAgent(BaseAgent): tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] if not (use_assistant_message and tool_return.name == "send_message"): - yield f"data: {tool_return.model_dump_json()}\n\n" + # Apply message type filtering if specified + if include_return_message_types is None or tool_return.message_type in include_return_message_types: + yield f"data: {tool_return.model_dump_json()}\n\n" if not should_continue: break diff --git a/letta/agents/voice_sleeptime_agent.py b/letta/agents/voice_sleeptime_agent.py index c8ebeb98..1d5abfde 100644 --- a/letta/agents/voice_sleeptime_agent.py +++ b/letta/agents/voice_sleeptime_agent.py @@ -7,7 +7,7 @@ from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.block import BlockUpdate from letta.schemas.enums import MessageStreamStatus -from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage +from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, MessageType from letta.schemas.letta_response import LettaResponse from letta.schemas.message import MessageCreate from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule @@ -59,7 +59,13 @@ class VoiceSleeptimeAgent(LettaAgent): def update_message_transcript(self, message_transcripts: List[str]): self.message_transcripts = 
message_transcripts - async def step(self, input_messages: List[MessageCreate], max_steps: int = 20, use_assistant_message: bool = True) -> LettaResponse: + async def step( + self, + input_messages: List[MessageCreate], + max_steps: int = 20, + use_assistant_message: bool = True, + include_return_message_types: Optional[List[MessageType]] = None, + ) -> LettaResponse: """ Process the user's input message, allowing the model to call memory-related tools until it decides to stop and provide a final response. @@ -86,7 +92,10 @@ class VoiceSleeptimeAgent(LettaAgent): ) return _create_letta_response( - new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage + new_in_context_messages=new_in_context_messages, + use_assistant_message=use_assistant_message, + usage=usage, + include_return_message_types=include_return_message_types, ) @trace_method diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index 587a8ef8..c88a9977 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -9,6 +9,7 @@ from letta.otel.tracing import trace_method from letta.schemas.enums import JobStatus from letta.schemas.group import Group, ManagerType from letta.schemas.job import JobUpdate +from letta.schemas.letta_message import MessageType from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageCreate @@ -63,6 +64,7 @@ class SleeptimeMultiAgentV2(BaseAgent): max_steps: int = 10, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, + include_return_message_types: Optional[List[MessageType]] = None, ) -> LettaResponse: run_ids = [] @@ -87,7 +89,10 @@ class SleeptimeMultiAgentV2(BaseAgent): ) # Perform foreground agent step response = await foreground_agent.step( - input_messages=new_messages, max_steps=max_steps, use_assistant_message=use_assistant_message + input_messages=new_messages, + max_steps=max_steps, + use_assistant_message=use_assistant_message, + include_return_message_types=include_return_message_types, ) # Get last response messages @@ -129,8 +134,11 @@ class SleeptimeMultiAgentV2(BaseAgent): max_steps: int = 10, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, + include_return_message_types: Optional[List[MessageType]] = None, ): - response = await self.step(input_messages, max_steps, use_assistant_message) + response = await self.step( + input_messages, max_steps, use_assistant_message, request_start_timestamp_ns, include_return_message_types + ) for message in response.messages: yield f"data: {message.model_dump_json()}\n\n" @@ -144,6 +152,7 @@ class SleeptimeMultiAgentV2(BaseAgent): max_steps: int = 10, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, + include_return_message_types: Optional[List[MessageType]] = None, ) -> AsyncGenerator[str, None]: # Prepare new messages new_messages = [] @@ -170,6 +179,7 @@ class SleeptimeMultiAgentV2(BaseAgent): max_steps=max_steps, use_assistant_message=use_assistant_message, request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=include_return_message_types, ): yield chunk diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index 6774b359..b7c4117f 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -3,6 +3,7 @@ from typing import 
List, Optional from pydantic import BaseModel, Field, HttpUrl from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.schemas.letta_message import MessageType from letta.schemas.message import MessageCreate @@ -21,6 +22,11 @@ class LettaRequest(BaseModel): description="The name of the message argument in the designated message tool.", ) + # filter to only return specific message types + include_return_message_types: Optional[List[MessageType]] = Field( + default=None, description="Only return specified message types in the response. If `None` (default) returns all messages." + ) + class LettaStreamingRequest(LettaRequest): stream_tokens: bool = Field( diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 4fc29bad..b7dfeb9a 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -23,7 +23,7 @@ from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent from letta.schemas.block import Block, BlockUpdate from letta.schemas.group import Group from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig -from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion +from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory @@ -704,6 +704,7 @@ async def send_message( max_steps=10, use_assistant_message=request.use_assistant_message, request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, ) else: result = await server.send_message_to_agent( @@ -716,6 +717,7 @@ async def send_message( use_assistant_message=request.use_assistant_message, assistant_message_tool_name=request.assistant_message_tool_name, assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + include_return_message_types=request.include_return_message_types, ) return result @@ -791,6 +793,7 @@ async def send_message_streaming( max_steps=10, use_assistant_message=request.use_assistant_message, request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, ), media_type="text/event-stream", ) @@ -801,6 +804,7 @@ async def send_message_streaming( max_steps=10, use_assistant_message=request.use_assistant_message, request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, ), media_type="text/event-stream", ) @@ -816,6 +820,7 @@ async def send_message_streaming( assistant_message_tool_name=request.assistant_message_tool_name, assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, ) return result @@ -830,6 +835,7 @@ async def process_message_background( use_assistant_message: bool, assistant_message_tool_name: str, assistant_message_tool_kwarg: str, + include_return_message_types: Optional[List[MessageType]] = None, ) -> None: """Background task to process the message and update job status.""" try: @@ -845,6 +851,7 @@ async def process_message_background( assistant_message_tool_kwarg=assistant_message_tool_kwarg, metadata={"job_id": job_id}, # Pass job_id through 
metadata request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=include_return_message_types, ) # Update job status to completed @@ -912,6 +919,7 @@ async def send_message_async( use_assistant_message=request.use_assistant_message, assistant_message_tool_name=request.assistant_message_tool_name, assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + include_return_message_types=request.include_return_message_types, ) return run diff --git a/letta/server/server.py b/letta/server/server.py index 3b417542..1ee944d3 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -45,7 +45,7 @@ from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager from letta.schemas.job import Job, JobUpdate -from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage +from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, MessageType, ToolReturnMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig @@ -2237,6 +2237,7 @@ class SyncServer(Server): assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, metadata: Optional[dict] = None, request_start_timestamp_ns: Optional[int] = None, + include_return_message_types: Optional[List[MessageType]] = None, ) -> Union[StreamingResponse, LettaResponse]: """Split off into a separate function so that it can be imported in the /chat/completion proxy.""" # TODO: @charles is this the correct way to handle? 
@@ -2342,6 +2343,11 @@ class SyncServer(Server): # Get rid of the stream status messages filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)] + + # Apply message type filtering if specified + if include_return_message_types is not None: + filtered_stream = [msg for msg in filtered_stream if msg.message_type in include_return_message_types] + usage = await task # By default the stream will be messages of type LettaMessage or LettaLegacyMessage diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index befab5df..943e779b 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -679,3 +679,86 @@ def test_many_blocks(client: LettaSDKClient): client.agents.delete(agent1.id) client.agents.delete(agent2.id) + + +# cases: steam, async, token stream, sync +@pytest.mark.parametrize("message_create", ["stream_step", "token_stream", "sync"]) +def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, message_create: str): + """Test that the include_return_message_types parameter works""" + + def verify_message_types(messages, message_types): + for message in messages: + assert message.message_type in message_types + + message = "My name is actually Sarah" + message_types = ["reasoning_message", "tool_call_message"] + agent = client.agents.create( + memory_blocks=[ + CreateBlock(label="user", value="Name: Charles"), + ], + model="letta/letta-free", + embedding="letta/letta-free", + ) + + if message_create == "stream_step": + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content=message, + ), + ], + include_return_message_types=message_types, + ) + messages = [message for message in list(response) if message.message_type != "usage_statistics"] + verify_message_types(messages, message_types) + + elif message_create == "async": + response = client.agents.messages.create_async( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content=message, + ) + ], + include_return_message_types=message_types, + ) + # wait to finish + while response.status != "completed": + time.sleep(1) + response = client.runs.retrieve(run_id=response.id) + messages = client.runs.messages.list(run_id=response.id) + verify_message_types(messages, message_types) + + elif message_create == "token_stream": + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content=message, + ), + ], + include_return_message_types=message_types, + ) + messages = [message for message in list(response) if message.message_type != "usage_statistics"] + verify_message_types(messages, message_types) + + elif message_create == "sync": + response = client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content=message, + ), + ], + include_return_message_types=message_types, + ) + messages = response.messages + verify_message_types(messages, message_types) + + # cleanup + client.agents.delete(agent.id)
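
Usage sketch (not part of the patch): the snippet below shows how a client might exercise the new `include_return_message_types` request field on both the blocking and streaming message endpoints, mirroring the test added in tests/test_sdk_client.py. The SDK package name (`letta_client`), the `MessageCreate` import path, and the base URL are assumptions for illustration, not confirmed by the diff.

# Minimal usage sketch, assuming the `letta_client` SDK and a local server.
from letta_client import Letta
from letta_client.types import MessageCreate  # import path assumed

client = Letta(base_url="http://localhost:8283")  # base URL assumed

agent = client.agents.create(
    model="letta/letta-free",
    embedding="letta/letta-free",
)

# Blocking request: only reasoning and tool-call messages are returned,
# matching the filter applied in _create_letta_response().
response = client.agents.messages.create(
    agent_id=agent.id,
    messages=[MessageCreate(role="user", content="My name is actually Sarah")],
    include_return_message_types=["reasoning_message", "tool_call_message"],
)
for message in response.messages:
    assert message.message_type in ("reasoning_message", "tool_call_message")

# Streaming request: the same filter is applied chunk by chunk on the SSE stream.
stream = client.agents.messages.create_stream(
    agent_id=agent.id,
    messages=[MessageCreate(role="user", content="And please update your memory")],
    include_return_message_types=["assistant_message"],
)
for chunk in stream:
    # Usage statistics are still emitted at the end of the stream and are not filtered,
    # which is why the test excludes them before verifying message types.
    if chunk.message_type != "usage_statistics":
        assert chunk.message_type == "assistant_message"

client.agents.delete(agent.id)

Leaving `include_return_message_types` unset (the default `None`) preserves the previous behavior and returns all message types, so existing callers are unaffected.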