feat: Add max_steps parameter to agent export (#3828)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from marshmallow import fields, post_dump, pre_load
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import letta
|
||||
@@ -15,6 +16,7 @@ from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigFie
|
||||
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
||||
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
|
||||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
|
||||
class MarshmallowAgentSchema(BaseSchema):
|
||||
@@ -41,9 +43,10 @@ class MarshmallowAgentSchema(BaseSchema):
|
||||
tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
|
||||
tags = fields.List(fields.Nested(SerializedAgentTagSchema))
|
||||
|
||||
def __init__(self, *args, session: sessionmaker, actor: User, **kwargs):
|
||||
def __init__(self, *args, session: sessionmaker, actor: User, max_steps: Optional[int] = None, **kwargs):
|
||||
super().__init__(*args, actor=actor, **kwargs)
|
||||
self.session = session
|
||||
self.max_steps = max_steps
|
||||
|
||||
# Propagate session and actor to nested schemas automatically
|
||||
for field in self.fields.values():
|
||||
@@ -64,16 +67,103 @@ class MarshmallowAgentSchema(BaseSchema):
|
||||
|
||||
with db_registry.session() as session:
|
||||
agent_id = data.get("id")
|
||||
msgs = (
|
||||
session.query(MessageModel)
|
||||
.filter(
|
||||
MessageModel.agent_id == agent_id,
|
||||
MessageModel.organization_id == self.actor.organization_id,
|
||||
|
||||
if self.max_steps is not None:
|
||||
# first, always get the system message
|
||||
system_msg = (
|
||||
session.query(MessageModel)
|
||||
.filter(
|
||||
MessageModel.agent_id == agent_id,
|
||||
MessageModel.organization_id == self.actor.organization_id,
|
||||
MessageModel.role == "system",
|
||||
)
|
||||
.order_by(MessageModel.sequence_id.asc())
|
||||
.first()
|
||||
)
|
||||
.order_by(MessageModel.sequence_id.asc())
|
||||
.all()
|
||||
)
|
||||
# overwrite the “messages” key with a fully serialized list
|
||||
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# efficient PostgreSQL approach using subquery
|
||||
user_msg_subquery = (
|
||||
session.query(MessageModel.sequence_id)
|
||||
.filter(
|
||||
MessageModel.agent_id == agent_id,
|
||||
MessageModel.organization_id == self.actor.organization_id,
|
||||
MessageModel.role == "user",
|
||||
)
|
||||
.order_by(MessageModel.sequence_id.desc())
|
||||
.limit(self.max_steps)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# get the minimum sequence_id from the subquery
|
||||
cutoff_sequence_id = session.query(func.min(user_msg_subquery.c.sequence_id)).scalar()
|
||||
|
||||
if cutoff_sequence_id:
|
||||
# get messages from cutoff, excluding system message to avoid duplicates
|
||||
step_msgs = (
|
||||
session.query(MessageModel)
|
||||
.filter(
|
||||
MessageModel.agent_id == agent_id,
|
||||
MessageModel.organization_id == self.actor.organization_id,
|
||||
MessageModel.sequence_id >= cutoff_sequence_id,
|
||||
MessageModel.role != "system",
|
||||
)
|
||||
.order_by(MessageModel.sequence_id.asc())
|
||||
.all()
|
||||
)
|
||||
# combine system message with step messages
|
||||
msgs = [system_msg] + step_msgs if system_msg else step_msgs
|
||||
else:
|
||||
# no user messages, just return system message
|
||||
msgs = [system_msg] if system_msg else []
|
||||
else:
|
||||
# sqlite approach: get all user messages first, then get messages from cutoff
|
||||
user_messages = (
|
||||
session.query(MessageModel.sequence_id)
|
||||
.filter(
|
||||
MessageModel.agent_id == agent_id,
|
||||
MessageModel.organization_id == self.actor.organization_id,
|
||||
MessageModel.role == "user",
|
||||
)
|
||||
.order_by(MessageModel.sequence_id.desc())
|
||||
.limit(self.max_steps)
|
||||
.all()
|
||||
)
|
||||
|
||||
if user_messages:
|
||||
# get the minimum sequence_id
|
||||
cutoff_sequence_id = min(msg.sequence_id for msg in user_messages)
|
||||
|
||||
# get messages from cutoff, excluding system message to avoid duplicates
|
||||
step_msgs = (
|
||||
session.query(MessageModel)
|
||||
.filter(
|
||||
MessageModel.agent_id == agent_id,
|
||||
MessageModel.organization_id == self.actor.organization_id,
|
||||
MessageModel.sequence_id >= cutoff_sequence_id,
|
||||
MessageModel.role != "system",
|
||||
)
|
||||
.order_by(MessageModel.sequence_id.asc())
|
||||
.all()
|
||||
)
|
||||
# combine system message with step messages
|
||||
msgs = [system_msg] + step_msgs if system_msg else step_msgs
|
||||
else:
|
||||
# no user messages, just return system message
|
||||
msgs = [system_msg] if system_msg else []
|
||||
else:
|
||||
# if no limit, get all messages in ascending order
|
||||
msgs = (
|
||||
session.query(MessageModel)
|
||||
.filter(
|
||||
MessageModel.agent_id == agent_id,
|
||||
MessageModel.organization_id == self.actor.organization_id,
|
||||
)
|
||||
.order_by(MessageModel.sequence_id.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
# overwrite the "messages" key with a fully serialized list
|
||||
data[self.FIELD_MESSAGES] = [SerializedMessageSchema(session=self.session, actor=self.actor).dump(m) for m in msgs]
|
||||
|
||||
return data
|
||||
|
||||
@@ -146,6 +146,7 @@ class IndentedORJSONResponse(Response):
|
||||
@router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized")
|
||||
def export_agent_serialized(
|
||||
agent_id: str,
|
||||
max_steps: int = 100,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
# do not remove, used for autogeneration of spec
|
||||
@@ -158,7 +159,7 @@ def export_agent_serialized(
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor)
|
||||
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps)
|
||||
return agent.model_dump()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||
|
||||
@@ -1446,10 +1446,10 @@ class AgentManager:
|
||||
|
||||
@enforce_types
@trace_method
def serialize(self, agent_id: str, actor: PydanticUser, max_steps: Optional[int] = None) -> AgentSchema:
    """Export one agent as an ``AgentSchema``.

    The agent row is read with ``AgentModel.read`` on behalf of ``actor`` and
    dumped through ``MarshmallowAgentSchema``; the resulting mapping is then
    unpacked into ``AgentSchema``.

    Args:
        agent_id: Identifier handed to ``AgentModel.read``.
        actor: User the read and the schema dump are performed for.
        max_steps: Forwarded unchanged to ``MarshmallowAgentSchema``.
            ``None`` (the default) passes no limit to the schema.

    Returns:
        The ``AgentSchema`` built from the dumped agent data.
    """
    with db_registry.session() as db_session:
        agent_row = AgentModel.read(db_session=db_session, identifier=agent_id, actor=actor)
        exporter = MarshmallowAgentSchema(session=db_session, actor=actor, max_steps=max_steps)
        dumped = exporter.dump(agent_row)
        return AgentSchema(**dumped)
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ def weather_tool_func():
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
|
||||
Parameters:
|
||||
Args:
|
||||
location (str): The location to get the weather for.
|
||||
|
||||
Returns:
|
||||
|
||||
977
tests/test_agent_files/max_messages.af
Normal file
977
tests/test_agent_files/max_messages.af
Normal file
File diff suppressed because one or more lines are too long
@@ -711,3 +711,51 @@ def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, oth
|
||||
agent_id=copied_agent_id,
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="Hello there!")],
|
||||
)
|
||||
|
||||
|
||||
def test_serialize_with_max_steps(server, server_url, default_user, other_user):
    """Check that ``max_steps`` trims a serialized agent to its most recent steps.

    The pre-populated ``max_messages.af`` agent file is imported over HTTP and
    then serialized three ways: without a limit, with ``max_steps=2`` and with
    ``max_steps=0``. The two-step export is also deserialized for another user
    and sent a message to confirm it is still usable.
    """
    import_options = {
        "append_copy_suffix": "false",
        "override_existing_tools": "false",
    }
    af_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "max_messages.af")

    # The upload has to happen while the file handle is still open.
    with open(af_path, "rb") as af_handle:
        response = requests.post(
            f"{server_url}/v1/agents/import",
            headers={"user_id": default_user.id},
            files={"file": ("max_messages.af", af_handle, "application/json")},
            data=import_options,
        )

    assert response.status_code == 200, f"Failed to upload agent: {response.text}"
    agent_id = response.json()["id"]

    # No max_steps argument: the parameter defaults to None and nothing is trimmed.
    unlimited = server.agent_manager.serialize(agent_id=agent_id, actor=default_user)
    total_messages = len(unlimited.messages)
    assert total_messages == 31, f"Expected 31 messages, got {total_messages}"

    # max_steps=2: everything from the second-to-last user message onward.
    two_steps = server.agent_manager.serialize(agent_id=agent_id, actor=default_user, max_steps=2)
    limited_user_count = len([m for m in two_steps.messages if m.role == "user"])
    assert limited_user_count == 2, f"Expected 2 user messages (steps), got {limited_user_count}"
    # presumably three messages per step in this fixture, plus the system message -- verify against the .af file
    assert len(two_steps.messages) == 2 * 3 + 1

    # A truncated export must still yield an agent that can take new messages.
    clone = server.agent_manager.deserialize(two_steps, actor=other_user, append_copy_suffix=True)
    greeting = [MessageCreate(role=MessageRole.user, content="Hello!")]
    reply = server.send_messages(actor=other_user, agent_id=clone.id, input_messages=greeting)
    assert reply is not None and reply.step_count > 0, "Agent should be able to receive and respond to messages"

    # max_steps=0: only the system message is left.
    empty_result = server.agent_manager.serialize(agent_id=agent_id, actor=default_user, max_steps=0)
    assert len(empty_result.messages) == 1, f"Expected 1 message (system), got {len(empty_result.messages)}"
    assert empty_result.messages[0].role == "system", "The only message should be the system message"
|
||||
|
||||
Reference in New Issue
Block a user