feat: Add max_steps parameter to agent export (#3828)

This commit is contained in:
Matthew Zhou
2025-08-08 13:38:29 -07:00
committed by GitHub
parent 85a7d136c9
commit 57526bf7d6
6 changed files with 1131 additions and 15 deletions

View File

@@ -1,6 +1,7 @@
from typing import Dict from typing import Dict, Optional
from marshmallow import fields, post_dump, pre_load from marshmallow import fields, post_dump, pre_load
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
import letta import letta
@@ -15,6 +16,7 @@ from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigFie
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
from letta.settings import DatabaseChoice, settings
class MarshmallowAgentSchema(BaseSchema): class MarshmallowAgentSchema(BaseSchema):
@@ -41,9 +43,10 @@ class MarshmallowAgentSchema(BaseSchema):
tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema)) tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
tags = fields.List(fields.Nested(SerializedAgentTagSchema)) tags = fields.List(fields.Nested(SerializedAgentTagSchema))
def __init__(self, *args, session: sessionmaker, actor: User, **kwargs): def __init__(self, *args, session: sessionmaker, actor: User, max_steps: Optional[int] = None, **kwargs):
super().__init__(*args, actor=actor, **kwargs) super().__init__(*args, actor=actor, **kwargs)
self.session = session self.session = session
self.max_steps = max_steps
# Propagate session and actor to nested schemas automatically # Propagate session and actor to nested schemas automatically
for field in self.fields.values(): for field in self.fields.values():
@@ -64,16 +67,103 @@ class MarshmallowAgentSchema(BaseSchema):
with db_registry.session() as session: with db_registry.session() as session:
agent_id = data.get("id") agent_id = data.get("id")
msgs = (
session.query(MessageModel) if self.max_steps is not None:
.filter( # first, always get the system message
MessageModel.agent_id == agent_id, system_msg = (
MessageModel.organization_id == self.actor.organization_id, session.query(MessageModel)
.filter(
MessageModel.agent_id == agent_id,
MessageModel.organization_id == self.actor.organization_id,
MessageModel.role == "system",
)
.order_by(MessageModel.sequence_id.asc())
.first()
) )
.order_by(MessageModel.sequence_id.asc())
.all() if settings.database_engine is DatabaseChoice.POSTGRES:
) # efficient PostgreSQL approach using subquery
# overwrite the “messages” key with a fully serialized list user_msg_subquery = (
session.query(MessageModel.sequence_id)
.filter(
MessageModel.agent_id == agent_id,
MessageModel.organization_id == self.actor.organization_id,
MessageModel.role == "user",
)
.order_by(MessageModel.sequence_id.desc())
.limit(self.max_steps)
.subquery()
)
# get the minimum sequence_id from the subquery
cutoff_sequence_id = session.query(func.min(user_msg_subquery.c.sequence_id)).scalar()
if cutoff_sequence_id:
# get messages from cutoff, excluding system message to avoid duplicates
step_msgs = (
session.query(MessageModel)
.filter(
MessageModel.agent_id == agent_id,
MessageModel.organization_id == self.actor.organization_id,
MessageModel.sequence_id >= cutoff_sequence_id,
MessageModel.role != "system",
)
.order_by(MessageModel.sequence_id.asc())
.all()
)
# combine system message with step messages
msgs = [system_msg] + step_msgs if system_msg else step_msgs
else:
# no user messages, just return system message
msgs = [system_msg] if system_msg else []
else:
# sqlite approach: fetch the most recent max_steps user messages first, then get messages from cutoff
user_messages = (
session.query(MessageModel.sequence_id)
.filter(
MessageModel.agent_id == agent_id,
MessageModel.organization_id == self.actor.organization_id,
MessageModel.role == "user",
)
.order_by(MessageModel.sequence_id.desc())
.limit(self.max_steps)
.all()
)
if user_messages:
# get the minimum sequence_id
cutoff_sequence_id = min(msg.sequence_id for msg in user_messages)
# get messages from cutoff, excluding system message to avoid duplicates
step_msgs = (
session.query(MessageModel)
.filter(
MessageModel.agent_id == agent_id,
MessageModel.organization_id == self.actor.organization_id,
MessageModel.sequence_id >= cutoff_sequence_id,
MessageModel.role != "system",
)
.order_by(MessageModel.sequence_id.asc())
.all()
)
# combine system message with step messages
msgs = [system_msg] + step_msgs if system_msg else step_msgs
else:
# no user messages, just return system message
msgs = [system_msg] if system_msg else []
else:
# if no limit, get all messages in ascending order
msgs = (
session.query(MessageModel)
.filter(
MessageModel.agent_id == agent_id,
MessageModel.organization_id == self.actor.organization_id,
)
.order_by(MessageModel.sequence_id.asc())
.all()
)
# overwrite the "messages" key with a fully serialized list
data[self.FIELD_MESSAGES] = [SerializedMessageSchema(session=self.session, actor=self.actor).dump(m) for m in msgs] data[self.FIELD_MESSAGES] = [SerializedMessageSchema(session=self.session, actor=self.actor).dump(m) for m in msgs]
return data return data

View File

@@ -146,6 +146,7 @@ class IndentedORJSONResponse(Response):
@router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized") @router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized")
def export_agent_serialized( def export_agent_serialized(
agent_id: str, agent_id: str,
max_steps: int = 100,
server: "SyncServer" = Depends(get_letta_server), server: "SyncServer" = Depends(get_letta_server),
actor_id: str | None = Header(None, alias="user_id"), actor_id: str | None = Header(None, alias="user_id"),
# do not remove, used for autogeneration of spec # do not remove, used for autogeneration of spec
@@ -158,7 +159,7 @@ def export_agent_serialized(
actor = server.user_manager.get_user_or_default(user_id=actor_id) actor = server.user_manager.get_user_or_default(user_id=actor_id)
try: try:
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor) agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps)
return agent.model_dump() return agent.model_dump()
except NoResultFound: except NoResultFound:
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")

View File

@@ -1446,10 +1446,10 @@ class AgentManager:
@enforce_types @enforce_types
@trace_method @trace_method
def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema: def serialize(self, agent_id: str, actor: PydanticUser, max_steps: Optional[int] = None) -> AgentSchema:
with db_registry.session() as session: with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
schema = MarshmallowAgentSchema(session=session, actor=actor) schema = MarshmallowAgentSchema(session=session, actor=actor, max_steps=max_steps)
data = schema.dump(agent) data = schema.dump(agent)
return AgentSchema(**data) return AgentSchema(**data)

View File

@@ -111,7 +111,7 @@ def weather_tool_func():
""" """
Fetches the current weather for a given location. Fetches the current weather for a given location.
Parameters: Args:
location (str): The location to get the weather for. location (str): The location to get the weather for.
Returns: Returns:

File diff suppressed because one or more lines are too long

View File

@@ -711,3 +711,51 @@ def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, oth
agent_id=copied_agent_id, agent_id=copied_agent_id,
input_messages=[MessageCreate(role=MessageRole.user, content="Hello there!")], input_messages=[MessageCreate(role=MessageRole.user, content="Hello there!")],
) )
def test_serialize_with_max_steps(server, server_url, default_user, other_user):
    """Check that ``serialize`` trims exported messages to the requested number of steps.

    The test imports an agent file with pre-populated messages through the HTTP
    import endpoint, then serializes that agent three ways: with no limit, with
    ``max_steps=2`` and with ``max_steps=0``. The two-step export is additionally
    deserialized for a second user and sent a message to confirm it is usable.
    """
    # locate the fixture agent file that ships next to this test module
    fixture_dir = os.path.join(os.path.dirname(__file__), "test_agent_files")
    agent_file_path = os.path.join(fixture_dir, "max_messages.af")

    upload_options = {
        "append_copy_suffix": "false",
        "override_existing_tools": "false",
    }

    # the file handle must stay open while requests streams the multipart body
    with open(agent_file_path, "rb") as agent_file:
        upload_response = requests.post(
            f"{server_url}/v1/agents/import",
            headers={"user_id": default_user.id},
            files={"file": ("max_messages.af", agent_file, "application/json")},
            data=upload_options,
        )

    assert upload_response.status_code == 200, f"Failed to upload agent: {upload_response.text}"
    imported_agent_id = upload_response.json()["id"]

    # no max_steps argument: the complete message history is exported
    unlimited_export = server.agent_manager.serialize(agent_id=imported_agent_id, actor=default_user)
    message_total = len(unlimited_export.messages)
    assert message_total == 31, f"Expected 31 messages, got {message_total}"

    # max_steps=2: only messages from the second-to-last user message onward
    two_step_export = server.agent_manager.serialize(agent_id=imported_agent_id, actor=default_user, max_steps=2)
    user_message_count = len([message for message in two_step_export.messages if message.role == "user"])
    assert user_message_count == 2, f"Expected 2 user messages (steps), got {user_message_count}"
    # NOTE(review): 2 * 3 + 1 presumably means three messages per step plus the
    # system message in this fixture -- confirm against max_messages.af
    assert len(two_step_export.messages) == 2 * 3 + 1

    # a truncated export must still deserialize into an agent that can take new input
    restored_agent = server.agent_manager.deserialize(two_step_export, actor=other_user, append_copy_suffix=True)
    greeting = MessageCreate(role=MessageRole.user, content="Hello!")
    reply = server.send_messages(actor=other_user, agent_id=restored_agent.id, input_messages=[greeting])
    assert reply is not None and reply.step_count > 0, "Agent should be able to receive and respond to messages"

    # max_steps=0: nothing but the system message is exported
    system_only_export = server.agent_manager.serialize(agent_id=imported_agent_id, actor=default_user, max_steps=0)
    assert len(system_only_export.messages) == 1, f"Expected 1 message (system), got {len(system_only_export.messages)}"
    assert system_only_export.messages[0].role == "system", "The only message should be the system message"