feat: structured outputs for send_message (#1764)

This commit is contained in:
Andy Li
2025-04-22 09:50:01 -07:00
committed by GitHub
parent 780d004b12
commit fa89ad859e
12 changed files with 413 additions and 17 deletions

View File

@@ -0,0 +1,31 @@
"""add support for structured_outputs in agents
Revision ID: 28b8765bdd0a
Revises: a3c7d62e08ca
Create Date: 2025-04-18 11:43:47.701786
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "28b8765bdd0a"
down_revision: Union[str, None] = "a3c7d62e08ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply the migration: add a nullable JSON ``response_format`` column to ``agents``."""
    response_format_column = sa.Column("response_format", sa.JSON(), nullable=True)
    op.add_column("agents", response_format_column)
def downgrade() -> None:
    """Revert the migration: drop the ``response_format`` column from ``agents``."""
    table_name, column_name = "agents", "response_format"
    op.drop_column(table_name, column_name)

View File

@@ -17,6 +17,7 @@ from letta.constants import (
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
LLM_MAX_TOKENS,
REQ_HEARTBEAT_MESSAGE,
SEND_MESSAGE_TOOL_NAME,
)
from letta.errors import ContextWindowExceededError
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
@@ -47,6 +48,7 @@ from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.response_format import ResponseFormatType
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
@@ -256,6 +258,28 @@ class Agent(BaseAgent):
# Return updated messages
return messages
def _runtime_override_tool_json_schema(
    self,
    functions_list: List[Dict | None],
) -> List[Dict | None]:
    """Override the tool JSON schema at runtime for a particular tool if conditions are met.

    Currently only rewrites the `message` parameter of the `send_message` tool so the
    LLM output is constrained by the agent's configured `response_format`:
    - `json_schema`: the user-supplied schema (its `"schema"` payload) replaces the
      plain-string `message` schema.
    - `json_object`: `message` becomes a free-form JSON object.
    - `text` (or no response_format at all): schemas are left untouched.

    Args:
        functions_list: Tool JSON schemas; entries may be None per the declared type.

    Returns:
        The same list, with the `send_message` entry mutated in place when applicable.
    """
    response_format = self.agent_state.response_format
    if response_format and response_format.type != ResponseFormatType.text:
        for func in functions_list:
            # Element type allows None — skip such entries instead of crashing on
            # func["name"] (the original body assumed every entry was a dict).
            if func is None or func["name"] != SEND_MESSAGE_TOOL_NAME:
                continue
            if response_format.type == ResponseFormatType.json_schema:
                func["parameters"]["properties"]["message"] = response_format.json_schema["schema"]
            elif response_format.type == ResponseFormatType.json_object:
                func["parameters"]["properties"]["message"] = {
                    "type": "object",
                    "description": "Message contents. All unicode (including emojis) are supported.",
                    "additionalProperties": True,
                    "properties": {},
                }
            # Only one send_message tool exists; stop after handling it.
            break
    return functions_list
@trace_method
def _get_ai_reply(
self,
@@ -269,27 +293,26 @@ class Agent(BaseAgent):
step_count: Optional[int] = None,
last_function_failed: bool = False,
put_inner_thoughts_first: bool = True,
) -> ChatCompletionResponse:
) -> ChatCompletionResponse | None:
"""Get response from LLM API with robust retry mechanism."""
log_telemetry(self.logger, "_get_ai_reply start")
available_tools = set([t.name for t in self.agent_state.tools])
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
available_tools=available_tools, last_function_response=self.last_function_response
)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
allowed_functions = (
agent_state_tool_jsons
if not allowed_tool_names
else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
)
# Get allowed tools or allow all if none are allowed
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
available_tools=available_tools, last_function_response=self.last_function_response
) or list(available_tools)
# Don't allow a tool to be called if it failed last time
if last_function_failed and self.tool_rules_solver.tool_call_history:
allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.tool_call_history[-1]]
if not allowed_functions:
allowed_tool_names = [f for f in allowed_tool_names if f != self.tool_rules_solver.tool_call_history[-1]]
if not allowed_tool_names:
return None
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
allowed_functions = self._runtime_override_tool_json_schema(allowed_functions)
# For the first message, force the initial tool if one is specified
force_tool_call = None
if (
@@ -419,7 +442,7 @@ class Agent(BaseAgent):
tool_call_id = response_message.tool_calls[0].id
assert tool_call_id is not None # should be defined
# only necessary to add the tool_cal_id to a function call (antipattern)
# only necessary to add the tool_call_id to a function call (antipattern)
# response_message_dict = response_message.model_dump()
# response_message_dict["tool_call_id"] = tool_call_id
@@ -514,6 +537,10 @@ class Agent(BaseAgent):
# Failure case 3: function failed during execution
# NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
# this is because the function/tool role message is only created once the function/tool has executed/returned
# handle cases where we return a json message
if "message" in function_args:
function_args["message"] = str(function_args.get("message", ""))
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
self.chunk_index += 1
try:

View File

@@ -32,6 +32,7 @@ from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.run import Run
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.source import Source, SourceCreate, SourceUpdate
@@ -100,6 +101,7 @@ class AbstractClient(object):
message_ids: Optional[List[str]] = None,
memory: Optional[Memory] = None,
tags: Optional[List[str]] = None,
response_format: Optional[ResponseFormatUnion] = None,
):
raise NotImplementedError
@@ -553,6 +555,7 @@ class RESTClient(AbstractClient):
initial_message_sequence: Optional[List[Message]] = None,
tags: Optional[List[str]] = None,
message_buffer_autoclear: bool = False,
response_format: Optional[ResponseFormatUnion] = None,
) -> AgentState:
"""Create an agent
@@ -615,6 +618,7 @@ class RESTClient(AbstractClient):
"include_base_tools": include_base_tools,
"message_buffer_autoclear": message_buffer_autoclear,
"include_multi_agent_tools": include_multi_agent_tools,
"response_format": response_format,
}
# Only add name if it's not None
@@ -653,6 +657,7 @@ class RESTClient(AbstractClient):
embedding_config: Optional[EmbeddingConfig] = None,
message_ids: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
response_format: Optional[ResponseFormatUnion] = None,
) -> AgentState:
"""
Update an existing agent
@@ -682,6 +687,7 @@ class RESTClient(AbstractClient):
llm_config=llm_config,
embedding_config=embedding_config,
message_ids=message_ids,
response_format=response_format,
)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
@@ -2425,6 +2431,7 @@ class LocalClient(AbstractClient):
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
message_ids: Optional[List[str]] = None,
response_format: Optional[ResponseFormatUnion] = None,
):
"""
Update an existing agent
@@ -2458,6 +2465,7 @@ class LocalClient(AbstractClient):
llm_config=llm_config,
embedding_config=embedding_config,
message_ids=message_ids,
response_format=response_format,
),
actor=self.user,
)

View File

@@ -47,13 +47,14 @@ DEFAULT_PERSONA = "sam_pov"
DEFAULT_HUMAN = "basic"
DEFAULT_PRESET = "memgpt_chat"
SEND_MESSAGE_TOOL_NAME = "send_message"
# Base tools that cannot be edited, as they access agent state directly
# Note that we don't include "conversation_search_date" for now
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
BASE_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_insert", "archival_memory_search"]
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
# Base tools if the memgpt agent has enable_sleeptime on
BASE_SLEEPTIME_CHAT_TOOLS = ["send_message", "conversation_search", "archival_memory_search"]
BASE_SLEEPTIME_CHAT_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_search"]
# Base memory tools for sleeptime agent
BASE_SLEEPTIME_TOOLS = [
"memory_replace",
@@ -72,7 +73,7 @@ LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_S
# The name of the tool used to send message to the user
# May not be relevant in cases where the agent has multiple ways to message the user (send_imessage, send_discord_message, ...)
# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent)
DEFAULT_MESSAGE_TOOL = "send_message"
DEFAULT_MESSAGE_TOOL = SEND_MESSAGE_TOOL_NAME
DEFAULT_MESSAGE_TOOL_KWARG = "message"
PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"

View File

@@ -22,6 +22,13 @@ from letta.schemas.letta_message_content import (
)
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import ToolReturn
from letta.schemas.response_format import (
JsonObjectResponseFormat,
JsonSchemaResponseFormat,
ResponseFormatType,
ResponseFormatUnion,
TextResponseFormat,
)
from letta.schemas.tool_rule import (
ChildToolRule,
ConditionalToolRule,
@@ -371,3 +378,25 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat
return None
return AgentStepState(**data)
# --------------------------
# Response Format Serialization
# --------------------------
def serialize_response_format(response_format: Optional[ResponseFormatUnion]) -> Optional[Dict[str, Any]]:
    """Convert a response-format pydantic model into its JSON-safe dict form.

    Returns None when no response format is provided, so the DB column stays NULL.
    """
    return response_format.model_dump(mode="json") if response_format else None
def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormatUnion]:
    """Rebuild a response-format model from its stored JSON dict.

    Args:
        data: The dict produced by `serialize_response_format`, or None/empty when
            no response format was stored.

    Returns:
        The concrete ResponseFormat subclass instance, or None for empty input.

    Raises:
        ValueError: If `data` carries an unrecognized `type` value. (Previously this
            case silently fell through and returned None, masking corrupted rows.)
    """
    if not data:
        return None
    if data["type"] == ResponseFormatType.text:
        return TextResponseFormat(**data)
    if data["type"] == ResponseFormatType.json_schema:
        return JsonSchemaResponseFormat(**data)
    if data["type"] == ResponseFormatType.json_object:
        return JsonObjectResponseFormat(**data)
    raise ValueError(f"Unknown response format type: {data['type']!r}")

View File

@@ -5,7 +5,7 @@ from sqlalchemy import JSON, Boolean, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn
from letta.orm.identity import Identity
from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization
@@ -15,6 +15,7 @@ from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.tool_rule import ToolRule
if TYPE_CHECKING:
@@ -48,6 +49,11 @@ class Agent(SqlalchemyBase, OrganizationMixin):
# This is dangerously flexible with the JSON type
message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.")
# Response Format
response_format: Mapped[Optional[ResponseFormatUnion]] = mapped_column(
ResponseFormatColumn, nullable=True, doc="The response format for the agent."
)
# Metadata and configs
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
llm_config: Mapped[Optional[LLMConfig]] = mapped_column(
@@ -168,6 +174,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"multi_agent_group": None,
"tool_exec_environment_variables": [],
"enable_sleeptime": None,
"response_format": self.response_format,
}
# Optional fields: only included if requested

View File

@@ -9,6 +9,7 @@ from letta.helpers.converters import (
deserialize_llm_config,
deserialize_message_content,
deserialize_poll_batch_response,
deserialize_response_format,
deserialize_tool_calls,
deserialize_tool_returns,
deserialize_tool_rules,
@@ -20,6 +21,7 @@ from letta.helpers.converters import (
serialize_llm_config,
serialize_message_content,
serialize_poll_batch_response,
serialize_response_format,
serialize_tool_calls,
serialize_tool_returns,
serialize_tool_rules,
@@ -168,3 +170,16 @@ class AgentStepStateColumn(TypeDecorator):
def process_result_value(self, value, dialect):
return deserialize_agent_step_state(value)
class ResponseFormatColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing an agent's response format as JSON.

    Persists any `ResponseFormatUnion` member via its JSON dump and rebuilds the
    concrete pydantic model when reading the row back. (Docstring previously
    described ToolRules — a copy-paste from the column type above.)
    """

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Pydantic model -> JSON-safe dict (None stays None).
        return serialize_response_format(value)

    def process_result_value(self, value, dialect):
        # JSON dict -> concrete ResponseFormat subclass (None stays None).
        return deserialize_response_format(value)

View File

@@ -14,6 +14,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ToolRule
@@ -66,6 +67,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
# llm information
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
response_format: Optional[ResponseFormatUnion] = Field(
None, description="The response format used by the agent when returning from `send_message`."
)
# This is an object representing the in-process state of a running `Agent`
# Field in this object can be theoretically edited by tools, and will be persisted by the ORM
@@ -180,6 +184,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
)
enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
@field_validator("name")
@classmethod
@@ -259,6 +264,7 @@ class UpdateAgent(BaseModel):
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
)
enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
class Config:
extra = "ignore" # Ignores extra fields

View File

@@ -0,0 +1,78 @@
from enum import Enum
from typing import Annotated, Any, Dict, Literal, Union
from pydantic import BaseModel, Field, validator
class ResponseFormatType(str, Enum):
    """The supported response-format modes for agent output.

    Members (values double as their wire representation):
    - ``text``: free-form plain text.
    - ``json_schema``: JSON constrained by a user-supplied schema.
    - ``json_object``: any valid JSON object.
    """

    text = "text"
    json_schema = "json_schema"
    json_object = "json_object"
class ResponseFormat(BaseModel):
    """Base class for all response formats.

    Concrete subclasses narrow ``type`` to a single ``Literal`` member, which lets
    the union be discriminated on ``type`` during validation.
    """

    # Discriminator field: each subclass pins this to exactly one ResponseFormatType.
    type: ResponseFormatType = Field(
        ...,
        description="The type of the response format.",
        # NOTE(review): `example=` is the legacy pydantic-v1 way to attach a schema
        # example; confirm whether json_schema_extra should be used instead — changing
        # it alters the generated OpenAPI schema, so left as-is here.
        example=ResponseFormatType.text,
    )
# ---------------------
# Response Format Types
# ---------------------

# Plain typing alias for the serialized (JSON dict) form of a response format —
# e.g. what the JSON database column holds. Not itself a SQLAlchemy type; the
# actual column type lives in the ORM's custom columns module.
ResponseFormatDict = Dict[str, Any]
class TextResponseFormat(ResponseFormat):
    """Response format for plain text responses.

    Selecting this is equivalent to leaving `response_format` unset: the agent's
    `send_message` tool schema is not modified at runtime.
    """

    # Pinned discriminator value for this variant.
    type: Literal[ResponseFormatType.text] = Field(
        ResponseFormatType.text,
        description="The type of the response format.",
    )
class JsonSchemaResponseFormat(ResponseFormat):
    """Response format for JSON schema-based responses.

    The `json_schema` payload is expected to be a wrapper dict containing a
    `"schema"` key (the actual JSON Schema), mirroring the OpenAI structured
    outputs shape; the agent runtime injects `json_schema["schema"]` into the
    `send_message` tool's `message` parameter.
    """

    # Pinned discriminator value for this variant.
    type: Literal[ResponseFormatType.json_schema] = Field(
        ResponseFormatType.json_schema,
        description="The type of the response format.",
    )
    json_schema: Dict[str, Any] = Field(
        ...,
        description="The JSON schema of the response.",
    )

    @validator("json_schema")
    def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        """Check the envelope shape of the provided schema wrapper.

        This does NOT verify that the contents form a valid JSON Schema; it only
        enforces that the wrapper is a dict with a top-level 'schema' key, which
        downstream schema injection relies on.
        """
        if not isinstance(v, dict):
            raise ValueError("JSON schema must be a dictionary")
        if "schema" not in v:
            # Bug fix: the check is for a 'schema' key, but the old message claimed
            # a '$schema' property was required, misleading API users.
            raise ValueError("JSON schema should include a 'schema' key")
        return v
class JsonObjectResponseFormat(ResponseFormat):
    """Response format for JSON object responses.

    No schema is supplied: the agent runtime replaces the `send_message` `message`
    parameter with a free-form object schema (`additionalProperties: true`).
    """

    # Pinned discriminator value for this variant.
    type: Literal[ResponseFormatType.json_object] = Field(
        ResponseFormatType.json_object,
        description="The type of the response format.",
    )
# Pydantic type for validation.
# Discriminated union over `type`: pydantic selects the matching subclass (and
# rejects unknown types) when parsing request payloads.
# Fix: the original wrapped a single PEP-604 union in Union[...] —
# `Union[A | B | C]` — which is redundant; use the plain comma form.
ResponseFormatUnion = Annotated[
    Union[TextResponseFormat, JsonSchemaResponseFormat, JsonObjectResponseFormat],
    Field(discriminator="type"),
]

View File

@@ -1240,10 +1240,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
and function_call.function.name == self.assistant_message_tool_name
and self.assistant_message_tool_kwarg in func_args
):
# Coerce content to `str` in cases where it's a JSON due to `response_format` being a JSON
processed_chunk = AssistantMessage(
id=msg_obj.id,
date=msg_obj.created_at,
content=func_args[self.assistant_message_tool_kwarg],
content=str(func_args[self.assistant_message_tool_kwarg]),
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)

View File

@@ -364,6 +364,7 @@ class AgentManager:
"base_template_id": agent_update.base_template_id,
"message_buffer_autoclear": agent_update.message_buffer_autoclear,
"enable_sleeptime": agent_update.enable_sleeptime,
"response_format": agent_update.response_format,
}
for col, val in scalar_updates.items():
if val is not None:

View File

@@ -0,0 +1,192 @@
# TODO (cliandy): Tested in SDK
# TODO (cliandy): Comment out after merge
# import os
# import threading
# import time
# import pytest
# from dotenv import load_dotenv
# from letta_client import AssistantMessage, AsyncLetta, Letta, Tool
# from letta.schemas.agent import AgentState
# from typing import List, Any, Dict
# # ------------------------------
# # Fixtures
# # ------------------------------
# @pytest.fixture(scope="module")
# def server_url() -> str:
# """
# Provides the URL for the Letta server.
# If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
# will start the Letta server in a background thread and return the default URL.
# """
# def _run_server() -> None:
# """Starts the Letta server in a background thread."""
# load_dotenv() # Load environment variables from .env file
# from letta.server.rest_api.app import start_server
# start_server(debug=True)
# # Retrieve server URL from environment, or default to localhost
# url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
# # If no environment variable is set, start the server in a background thread
# if not os.getenv("LETTA_SERVER_URL"):
# thread = threading.Thread(target=_run_server, daemon=True)
# thread.start()
# time.sleep(5) # Allow time for the server to start
# return url
# @pytest.fixture
# def client(server_url: str) -> Letta:
# """
# Creates and returns a synchronous Letta REST client for testing.
# """
# client_instance = Letta(base_url=server_url)
# yield client_instance
# @pytest.fixture
# def async_client(server_url: str) -> AsyncLetta:
# """
# Creates and returns an asynchronous Letta REST client for testing.
# """
# async_client_instance = AsyncLetta(base_url=server_url)
# yield async_client_instance
# @pytest.fixture
# def roll_dice_tool(client: Letta) -> Tool:
# """
# Registers a simple roll dice tool with the provided client.
# The tool simulates rolling a six-sided die but returns a fixed result.
# """
# def roll_dice() -> str:
# """
# Simulates rolling a die.
# Returns:
# str: The roll result.
# """
# # Note: The result here is intentionally incorrect for demonstration purposes.
# return "Rolled a 10!"
# tool = client.tools.upsert_from_function(func=roll_dice)
# yield tool
# @pytest.fixture
# def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
# """
# Creates and returns an agent state for testing with a pre-configured agent.
# The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
# """
# agent_state_instance = client.agents.create(
# name="supervisor",
# include_base_tools=True,
# tool_ids=[roll_dice_tool.id],
# model="openai/gpt-4o",
# embedding="letta/letta-free",
# tags=["supervisor"],
# include_base_tool_rules=True,
# )
# yield agent_state_instance
# # Goal is to test that when an Agent is created with a `response_format`, that the response
# # of `send_message` is in the correct format. This will be done by modifying the agent's
# # `send_message` tool so that it returns a format based on what is passed in.
# #
# # `response_format` is an optional field
# # if `response_format.type` is `text`, then the schema does not change
# # if `response_format.type` is `json_object`, then the schema is a dict
# # if `response_format.type` is `json_schema`, then the schema is a dict matching that json schema
# USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Send me a message."}]
# # ------------------------------
# # Test Cases
# # ------------------------------
# def test_client_send_message_text_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_object'."""
# client.agents.modify(agent.id, response_format={"type": "text"})
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, str)
# def test_client_send_message_json_object_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_object'."""
# client.agents.modify(agent.id, response_format={"type": "json_object"})
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, dict)
# def test_client_send_message_json_schema_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_schema' and a valid schema."""
# client.agents.modify(agent.id, response_format={
# "type": "json_schema",
# "json_schema": {
# "name": "reasoning_schema",
# "schema": {
# "type": "object",
# "properties": {
# "steps": {
# "type": "array",
# "items": {
# "type": "object",
# "properties": {
# "explanation": { "type": "string" },
# "output": { "type": "string" }
# },
# "required": ["explanation", "output"],
# "additionalProperties": False
# }
# },
# "final_answer": { "type": "string" }
# },
# "required": ["steps", "final_answer"],
# "additionalProperties": True
# },
# "strict": True
# }
# })
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, dict)
# # def test_client_send_message_invalid_json_schema(client: "Letta", agent: "AgentState") -> None:
# # """Test client send_message with an invalid json_schema (should error or fallback)."""
# # invalid_schema: Dict[str, Any] = {"type": "object", "properties": {"foo": {"type": "unknown"}}}
# # client.agents.modify(agent.id, response_format="json_schema")
# # result: Any = client.agents.send_message(agent.id, "Test invalid schema")
# # assert result is None or "error" in str(result).lower()