feat: Add max_steps parameter to agent export (#3828)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from marshmallow import fields, post_dump, pre_load
|
from marshmallow import fields, post_dump, pre_load
|
||||||
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
import letta
|
import letta
|
||||||
@@ -15,6 +16,7 @@ from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigFie
|
|||||||
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
||||||
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
|
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
|
||||||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||||||
|
from letta.settings import DatabaseChoice, settings
|
||||||
|
|
||||||
|
|
||||||
class MarshmallowAgentSchema(BaseSchema):
|
class MarshmallowAgentSchema(BaseSchema):
|
||||||
@@ -41,9 +43,10 @@ class MarshmallowAgentSchema(BaseSchema):
|
|||||||
tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
|
tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
|
||||||
tags = fields.List(fields.Nested(SerializedAgentTagSchema))
|
tags = fields.List(fields.Nested(SerializedAgentTagSchema))
|
||||||
|
|
||||||
def __init__(self, *args, session: sessionmaker, actor: User, **kwargs):
|
def __init__(self, *args, session: sessionmaker, actor: User, max_steps: Optional[int] = None, **kwargs):
|
||||||
super().__init__(*args, actor=actor, **kwargs)
|
super().__init__(*args, actor=actor, **kwargs)
|
||||||
self.session = session
|
self.session = session
|
||||||
|
self.max_steps = max_steps
|
||||||
|
|
||||||
# Propagate session and actor to nested schemas automatically
|
# Propagate session and actor to nested schemas automatically
|
||||||
for field in self.fields.values():
|
for field in self.fields.values():
|
||||||
@@ -64,16 +67,103 @@ class MarshmallowAgentSchema(BaseSchema):
|
|||||||
|
|
||||||
with db_registry.session() as session:
|
with db_registry.session() as session:
|
||||||
agent_id = data.get("id")
|
agent_id = data.get("id")
|
||||||
msgs = (
|
|
||||||
session.query(MessageModel)
|
if self.max_steps is not None:
|
||||||
.filter(
|
# first, always get the system message
|
||||||
MessageModel.agent_id == agent_id,
|
system_msg = (
|
||||||
MessageModel.organization_id == self.actor.organization_id,
|
session.query(MessageModel)
|
||||||
|
.filter(
|
||||||
|
MessageModel.agent_id == agent_id,
|
||||||
|
MessageModel.organization_id == self.actor.organization_id,
|
||||||
|
MessageModel.role == "system",
|
||||||
|
)
|
||||||
|
.order_by(MessageModel.sequence_id.asc())
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.order_by(MessageModel.sequence_id.asc())
|
|
||||||
.all()
|
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||||
)
|
# efficient PostgreSQL approach using subquery
|
||||||
# overwrite the “messages” key with a fully serialized list
|
user_msg_subquery = (
|
||||||
|
session.query(MessageModel.sequence_id)
|
||||||
|
.filter(
|
||||||
|
MessageModel.agent_id == agent_id,
|
||||||
|
MessageModel.organization_id == self.actor.organization_id,
|
||||||
|
MessageModel.role == "user",
|
||||||
|
)
|
||||||
|
.order_by(MessageModel.sequence_id.desc())
|
||||||
|
.limit(self.max_steps)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the minimum sequence_id from the subquery
|
||||||
|
cutoff_sequence_id = session.query(func.min(user_msg_subquery.c.sequence_id)).scalar()
|
||||||
|
|
||||||
|
if cutoff_sequence_id:
|
||||||
|
# get messages from cutoff, excluding system message to avoid duplicates
|
||||||
|
step_msgs = (
|
||||||
|
session.query(MessageModel)
|
||||||
|
.filter(
|
||||||
|
MessageModel.agent_id == agent_id,
|
||||||
|
MessageModel.organization_id == self.actor.organization_id,
|
||||||
|
MessageModel.sequence_id >= cutoff_sequence_id,
|
||||||
|
MessageModel.role != "system",
|
||||||
|
)
|
||||||
|
.order_by(MessageModel.sequence_id.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
# combine system message with step messages
|
||||||
|
msgs = [system_msg] + step_msgs if system_msg else step_msgs
|
||||||
|
else:
|
||||||
|
# no user messages, just return system message
|
||||||
|
msgs = [system_msg] if system_msg else []
|
||||||
|
else:
|
||||||
|
# sqlite approach: get all user messages first, then get messages from cutoff
|
||||||
|
user_messages = (
|
||||||
|
session.query(MessageModel.sequence_id)
|
||||||
|
.filter(
|
||||||
|
MessageModel.agent_id == agent_id,
|
||||||
|
MessageModel.organization_id == self.actor.organization_id,
|
||||||
|
MessageModel.role == "user",
|
||||||
|
)
|
||||||
|
.order_by(MessageModel.sequence_id.desc())
|
||||||
|
.limit(self.max_steps)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_messages:
|
||||||
|
# get the minimum sequence_id
|
||||||
|
cutoff_sequence_id = min(msg.sequence_id for msg in user_messages)
|
||||||
|
|
||||||
|
# get messages from cutoff, excluding system message to avoid duplicates
|
||||||
|
step_msgs = (
|
||||||
|
session.query(MessageModel)
|
||||||
|
.filter(
|
||||||
|
MessageModel.agent_id == agent_id,
|
||||||
|
MessageModel.organization_id == self.actor.organization_id,
|
||||||
|
MessageModel.sequence_id >= cutoff_sequence_id,
|
||||||
|
MessageModel.role != "system",
|
||||||
|
)
|
||||||
|
.order_by(MessageModel.sequence_id.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
# combine system message with step messages
|
||||||
|
msgs = [system_msg] + step_msgs if system_msg else step_msgs
|
||||||
|
else:
|
||||||
|
# no user messages, just return system message
|
||||||
|
msgs = [system_msg] if system_msg else []
|
||||||
|
else:
|
||||||
|
# if no limit, get all messages in ascending order
|
||||||
|
msgs = (
|
||||||
|
session.query(MessageModel)
|
||||||
|
.filter(
|
||||||
|
MessageModel.agent_id == agent_id,
|
||||||
|
MessageModel.organization_id == self.actor.organization_id,
|
||||||
|
)
|
||||||
|
.order_by(MessageModel.sequence_id.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
# overwrite the "messages" key with a fully serialized list
|
||||||
data[self.FIELD_MESSAGES] = [SerializedMessageSchema(session=self.session, actor=self.actor).dump(m) for m in msgs]
|
data[self.FIELD_MESSAGES] = [SerializedMessageSchema(session=self.session, actor=self.actor).dump(m) for m in msgs]
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ class IndentedORJSONResponse(Response):
|
|||||||
@router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized")
|
@router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized")
|
||||||
def export_agent_serialized(
|
def export_agent_serialized(
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
max_steps: int = 100,
|
||||||
server: "SyncServer" = Depends(get_letta_server),
|
server: "SyncServer" = Depends(get_letta_server),
|
||||||
actor_id: str | None = Header(None, alias="user_id"),
|
actor_id: str | None = Header(None, alias="user_id"),
|
||||||
# do not remove, used to autogeneration of spec
|
# do not remove, used to autogeneration of spec
|
||||||
@@ -158,7 +159,7 @@ def export_agent_serialized(
|
|||||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor)
|
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps)
|
||||||
return agent.model_dump()
|
return agent.model_dump()
|
||||||
except NoResultFound:
|
except NoResultFound:
|
||||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||||
|
|||||||
@@ -1446,10 +1446,10 @@ class AgentManager:
|
|||||||
|
|
||||||
@enforce_types
@trace_method
def serialize(self, agent_id: str, actor: PydanticUser, max_steps: Optional[int] = None) -> AgentSchema:
    """Serialize an agent into an ``AgentSchema`` for export.

    Args:
        agent_id: Identifier of the agent to read.
        actor: User on whose behalf the agent is read and dumped.
        max_steps: Forwarded to ``MarshmallowAgentSchema``. ``None`` (the
            default) applies no message limit. Must not be negative.

    Returns:
        An ``AgentSchema`` built from the marshmallow dump of the agent.

    Raises:
        ValueError: If ``max_steps`` is negative.
    """
    # The schema hands max_steps directly to a SQL LIMIT. A negative LIMIT is
    # rejected by PostgreSQL but means "no limit" on SQLite, so reject it here
    # to get the same outcome on both backends before touching the database.
    if max_steps is not None and max_steps < 0:
        raise ValueError(f"max_steps must be non-negative, got {max_steps}")

    with db_registry.session() as session:
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        schema = MarshmallowAgentSchema(session=session, actor=actor, max_steps=max_steps)
        data = schema.dump(agent)
        return AgentSchema(**data)
|
||||||
|
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ def weather_tool_func():
|
|||||||
"""
|
"""
|
||||||
Fetches the current weather for a given location.
|
Fetches the current weather for a given location.
|
||||||
|
|
||||||
Parameters:
|
Args:
|
||||||
location (str): The location to get the weather for.
|
location (str): The location to get the weather for.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
977
tests/test_agent_files/max_messages.af
Normal file
977
tests/test_agent_files/max_messages.af
Normal file
File diff suppressed because one or more lines are too long
@@ -711,3 +711,51 @@ def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, oth
|
|||||||
agent_id=copied_agent_id,
|
agent_id=copied_agent_id,
|
||||||
input_messages=[MessageCreate(role=MessageRole.user, content="Hello there!")],
|
input_messages=[MessageCreate(role=MessageRole.user, content="Hello there!")],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_with_max_steps(server, server_url, default_user, other_user):
    """Check that ``max_steps`` bounds the exported messages by user turns.

    Imports the ``max_messages.af`` fixture over HTTP, then serializes the
    resulting agent three ways: without a limit, with ``max_steps=2`` and with
    ``max_steps=0``. The two-step export is also deserialized for a second
    user and sent a message to show the truncated copy is still usable.
    """
    # Upload the pre-populated agent file through the import endpoint.
    agent_file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "max_messages.af")
    import_options = {
        "append_copy_suffix": "false",
        "override_existing_tools": "false",
    }
    with open(agent_file_path, "rb") as agent_file:
        # The request has to go out while the file handle is still open.
        response = requests.post(
            f"{server_url}/v1/agents/import",
            headers={"user_id": default_user.id},
            files={"file": ("max_messages.af", agent_file, "application/json")},
            data=import_options,
        )

    assert response.status_code == 200, f"Failed to upload agent: {response.text}"
    agent_id = response.json()["id"]

    # No max_steps given: serialize() defaults to None and returns every message.
    unlimited = server.agent_manager.serialize(agent_id=agent_id, actor=default_user)
    total_messages = len(unlimited.messages)
    assert total_messages == 31, f"Expected 31 messages, got {total_messages}"

    # max_steps=2: keep everything from the second-to-last user message onward.
    two_steps = server.agent_manager.serialize(agent_id=agent_id, actor=default_user, max_steps=2)
    limited_user_count = len([m for m in two_steps.messages if m.role == "user"])
    assert limited_user_count == 2, f"Expected 2 user messages (steps), got {limited_user_count}"
    # NOTE(review): presumably 3 messages per step in this fixture plus the
    # system message -- confirm against max_messages.af.
    expected_total = 2 * 3 + 1
    assert len(two_steps.messages) == expected_total

    # A copy rebuilt from the truncated export must still accept new input.
    restored = server.agent_manager.deserialize(two_steps, actor=other_user, append_copy_suffix=True)
    reply = server.send_messages(
        actor=other_user,
        agent_id=restored.id,
        input_messages=[MessageCreate(role=MessageRole.user, content="Hello!")],
    )
    assert reply is not None and reply.step_count > 0, "Agent should be able to receive and respond to messages"

    # max_steps=0: no user turns are kept, leaving only the system message.
    empty_result = server.agent_manager.serialize(agent_id=agent_id, actor=default_user, max_steps=0)
    assert len(empty_result.messages) == 1, f"Expected 1 message (system), got {len(empty_result.messages)}"
    assert empty_result.messages[0].role == "system", "The only message should be the system message"
|
||||||
|
|||||||
Reference in New Issue
Block a user