diff --git a/alembic/versions/28b8765bdd0a_add_support_for_structured_outputs_in_.py b/alembic/versions/28b8765bdd0a_add_support_for_structured_outputs_in_.py new file mode 100644 index 00000000..e5bcab08 --- /dev/null +++ b/alembic/versions/28b8765bdd0a_add_support_for_structured_outputs_in_.py @@ -0,0 +1,31 @@ +"""add support for structured_outputs in agents + +Revision ID: 28b8765bdd0a +Revises: a3c7d62e08ca +Create Date: 2025-04-18 11:43:47.701786 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "28b8765bdd0a" +down_revision: Union[str, None] = "a3c7d62e08ca" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("agents", sa.Column("response_format", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("agents", "response_format") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 7dba5ead..8ec8e204 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -17,6 +17,7 @@ from letta.constants import ( LETTA_MULTI_AGENT_TOOL_MODULE_NAME, LLM_MAX_TOKENS, REQ_HEARTBEAT_MESSAGE, + SEND_MESSAGE_TOOL_NAME, ) from letta.errors import ContextWindowExceededError from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source @@ -47,6 +48,7 @@ from letta.schemas.message import Message, MessageCreate, ToolReturn from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.response_format import ResponseFormatType from letta.schemas.sandbox_config import SandboxRunResult from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule @@ -256,6 +258,28 @@ class Agent(BaseAgent): # Return updated messages return messages + def _runtime_override_tool_json_schema( + self, + functions_list: List[Dict | None], + ) -> List[Dict | None]: + """Override the tool JSON schema at runtime for a particular tool if conditions are met.""" + + # Currently just injects `send_message` with a `response_format` if provided to the agent. + if self.agent_state.response_format and self.agent_state.response_format.type != ResponseFormatType.text: + for func in functions_list: + if func["name"] == SEND_MESSAGE_TOOL_NAME: + if self.agent_state.response_format.type == ResponseFormatType.json_schema: + func["parameters"]["properties"]["message"] = self.agent_state.response_format.json_schema["schema"] + if self.agent_state.response_format.type == ResponseFormatType.json_object: + func["parameters"]["properties"]["message"] = { + "type": "object", + "description": "Message contents. 
All unicode (including emojis) are supported.", + "additionalProperties": True, + "properties": {}, + } + break + return functions_list + @trace_method def _get_ai_reply( self, @@ -269,27 +293,26 @@ class Agent(BaseAgent): step_count: Optional[int] = None, last_function_failed: bool = False, put_inner_thoughts_first: bool = True, - ) -> ChatCompletionResponse: + ) -> ChatCompletionResponse | None: """Get response from LLM API with robust retry mechanism.""" log_telemetry(self.logger, "_get_ai_reply start") available_tools = set([t.name for t in self.agent_state.tools]) - allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names( - available_tools=available_tools, last_function_response=self.last_function_response - ) agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] - allowed_functions = ( - agent_state_tool_jsons - if not allowed_tool_names - else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names] - ) + # Get allowed tools or allow all if none are allowed + allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names( + available_tools=available_tools, last_function_response=self.last_function_response + ) or list(available_tools) # Don't allow a tool to be called if it failed last time if last_function_failed and self.tool_rules_solver.tool_call_history: - allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.tool_call_history[-1]] - if not allowed_functions: + allowed_tool_names = [f for f in allowed_tool_names if f != self.tool_rules_solver.tool_call_history[-1]] + if not allowed_tool_names: return None + allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names] + allowed_functions = self._runtime_override_tool_json_schema(allowed_functions) + # For the first message, force the initial tool if one is specified force_tool_call = None if ( @@ -419,7 +442,7 @@ class Agent(BaseAgent): tool_call_id = 
response_message.tool_calls[0].id assert tool_call_id is not None # should be defined - # only necessary to add the tool_cal_id to a function call (antipattern) + # only necessary to add the tool_call_id to a function call (antipattern) # response_message_dict = response_message.model_dump() # response_message_dict["tool_call_id"] = tool_call_id @@ -514,6 +537,10 @@ class Agent(BaseAgent): # Failure case 3: function failed during execution # NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message # this is because the function/tool role message is only created once the function/tool has executed/returned + + # handle cases where we return a json message + if "message" in function_args: + function_args["message"] = str(function_args.get("message", "")) self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index) self.chunk_index += 1 try: diff --git a/letta/client/client.py b/letta/client/client.py index 7effa659..04572a6b 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -32,6 +32,7 @@ from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.organization import Organization from letta.schemas.passage import Passage +from letta.schemas.response_format import ResponseFormatUnion from letta.schemas.run import Run from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.source import Source, SourceCreate, SourceUpdate @@ -100,6 +101,7 @@ class AbstractClient(object): message_ids: Optional[List[str]] = None, memory: Optional[Memory] = None, tags: Optional[List[str]] = None, + response_format: Optional[ResponseFormatUnion] = None, ): raise NotImplementedError @@ -553,6 +555,7 @@ class RESTClient(AbstractClient): 
initial_message_sequence: Optional[List[Message]] = None, tags: Optional[List[str]] = None, message_buffer_autoclear: bool = False, + response_format: Optional[ResponseFormatUnion] = None, ) -> AgentState: """Create an agent @@ -615,6 +618,7 @@ class RESTClient(AbstractClient): "include_base_tools": include_base_tools, "message_buffer_autoclear": message_buffer_autoclear, "include_multi_agent_tools": include_multi_agent_tools, + "response_format": response_format, } # Only add name if it's not None @@ -653,6 +657,7 @@ class RESTClient(AbstractClient): embedding_config: Optional[EmbeddingConfig] = None, message_ids: Optional[List[str]] = None, tags: Optional[List[str]] = None, + response_format: Optional[ResponseFormatUnion] = None, ) -> AgentState: """ Update an existing agent @@ -682,6 +687,7 @@ class RESTClient(AbstractClient): llm_config=llm_config, embedding_config=embedding_config, message_ids=message_ids, + response_format=response_format, ) response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: @@ -2425,6 +2431,7 @@ class LocalClient(AbstractClient): llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, message_ids: Optional[List[str]] = None, + response_format: Optional[ResponseFormatUnion] = None, ): """ Update an existing agent @@ -2458,6 +2465,7 @@ class LocalClient(AbstractClient): llm_config=llm_config, embedding_config=embedding_config, message_ids=message_ids, + response_format=response_format, ), actor=self.user, ) diff --git a/letta/constants.py b/letta/constants.py index b8913503..64e49d07 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -47,13 +47,14 @@ DEFAULT_PERSONA = "sam_pov" DEFAULT_HUMAN = "basic" DEFAULT_PRESET = "memgpt_chat" +SEND_MESSAGE_TOOL_NAME = "send_message" # Base tools that cannot be edited, as they access agent state directly # Note that we don't include 
"conversation_search_date" for now -BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"] +BASE_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_insert", "archival_memory_search"] # Base memory tools CAN be edited, and are added by default by the server BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] # Base tools if the memgpt agent has enable_sleeptime on -BASE_SLEEPTIME_CHAT_TOOLS = ["send_message", "conversation_search", "archival_memory_search"] +BASE_SLEEPTIME_CHAT_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_search"] # Base memory tools for sleeptime agent BASE_SLEEPTIME_TOOLS = [ "memory_replace", @@ -72,7 +73,7 @@ LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_S # The name of the tool used to send message to the user # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) # or in cases where the agent has no concept of messaging a user (e.g. 
a workflow agent) -DEFAULT_MESSAGE_TOOL = "send_message" +DEFAULT_MESSAGE_TOOL = SEND_MESSAGE_TOOL_NAME DEFAULT_MESSAGE_TOOL_KWARG = "message" PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg" diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 88dacef9..45a45b6d 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -22,6 +22,13 @@ from letta.schemas.letta_message_content import ( ) from letta.schemas.llm_config import LLMConfig from letta.schemas.message import ToolReturn +from letta.schemas.response_format import ( + JsonObjectResponseFormat, + JsonSchemaResponseFormat, + ResponseFormatType, + ResponseFormatUnion, + TextResponseFormat, +) from letta.schemas.tool_rule import ( ChildToolRule, ConditionalToolRule, @@ -371,3 +378,25 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat return None return AgentStepState(**data) + + +# -------------------------- +# Response Format Serialization +# -------------------------- + + +def serialize_response_format(response_format: Optional[ResponseFormatUnion]) -> Optional[Dict[str, Any]]: + if not response_format: + return None + return response_format.model_dump(mode="json") + + +def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormatUnion]: + if not data: + return None + if data["type"] == ResponseFormatType.text: + return TextResponseFormat(**data) + if data["type"] == ResponseFormatType.json_schema: + return JsonSchemaResponseFormat(**data) + if data["type"] == ResponseFormatType.json_object: + return JsonObjectResponseFormat(**data) diff --git a/letta/orm/agent.py b/letta/orm/agent.py index fe5d7a74..ed5deb5a 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -5,7 +5,7 @@ from sqlalchemy import JSON, Boolean, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.block import Block -from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn 
+from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn from letta.orm.identity import Identity from letta.orm.mixins import OrganizationMixin from letta.orm.organization import Organization @@ -15,6 +15,7 @@ from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory +from letta.schemas.response_format import ResponseFormatUnion from letta.schemas.tool_rule import ToolRule if TYPE_CHECKING: @@ -48,6 +49,11 @@ class Agent(SqlalchemyBase, OrganizationMixin): # This is dangerously flexible with the JSON type message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.") + # Response Format + response_format: Mapped[Optional[ResponseFormatUnion]] = mapped_column( + ResponseFormatColumn, nullable=True, doc="The response format for the agent." 
+ ) + # Metadata and configs metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.") llm_config: Mapped[Optional[LLMConfig]] = mapped_column( @@ -168,6 +174,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): "multi_agent_group": None, "tool_exec_environment_variables": [], "enable_sleeptime": None, + "response_format": self.response_format, } # Optional fields: only included if requested diff --git a/letta/orm/custom_columns.py b/letta/orm/custom_columns.py index 77346406..5bc1c7dc 100644 --- a/letta/orm/custom_columns.py +++ b/letta/orm/custom_columns.py @@ -9,6 +9,7 @@ from letta.helpers.converters import ( deserialize_llm_config, deserialize_message_content, deserialize_poll_batch_response, + deserialize_response_format, deserialize_tool_calls, deserialize_tool_returns, deserialize_tool_rules, @@ -20,6 +21,7 @@ from letta.helpers.converters import ( serialize_llm_config, serialize_message_content, serialize_poll_batch_response, + serialize_response_format, serialize_tool_calls, serialize_tool_returns, serialize_tool_rules, @@ -168,3 +170,16 @@ class AgentStepStateColumn(TypeDecorator): def process_result_value(self, value, dialect): return deserialize_agent_step_state(value) + + +class ResponseFormatColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing a ResponseFormat as JSON.""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value, dialect): + return serialize_response_format(value) + + def process_result_value(self, value, dialect): + return deserialize_response_format(value) diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 6d30afcd..b6b84183 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -14,6 +14,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import UsageStatistics +from 
letta.schemas.response_format import ResponseFormatUnion from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.tool_rule import ToolRule @@ -66,6 +67,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True): # llm information llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.") + response_format: Optional[ResponseFormatUnion] = Field( + None, description="The response format used by the agent when returning from `send_message`." + ) # This is an object representing the in-process state of a running `Agent` # Field in this object can be theoretically edited by tools, and will be persisted by the ORM @@ -180,6 +184,7 @@ class CreateAgent(BaseModel, validate_assignment=True): # description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", ) enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.") + response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.") @field_validator("name") @classmethod @@ -259,6 +264,7 @@ class UpdateAgent(BaseModel): None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name." 
) enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.") + response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.") class Config: extra = "ignore" # Ignores extra fields diff --git a/letta/schemas/response_format.py b/letta/schemas/response_format.py new file mode 100644 index 00000000..08197c57 --- /dev/null +++ b/letta/schemas/response_format.py @@ -0,0 +1,78 @@ +from enum import Enum +from typing import Annotated, Any, Dict, Literal, Union + +from pydantic import BaseModel, Field, validator + + +class ResponseFormatType(str, Enum): + """Enum defining the possible response format types.""" + + text = "text" + json_schema = "json_schema" + json_object = "json_object" + + +class ResponseFormat(BaseModel): + """Base class for all response formats.""" + + type: ResponseFormatType = Field( + ..., + description="The type of the response format.", + # why use this? 
+ example=ResponseFormatType.text, + ) + + +# --------------------- +# Response Format Types +# --------------------- + +# SQLAlchemy type for database mapping +ResponseFormatDict = Dict[str, Any] + + +class TextResponseFormat(ResponseFormat): + """Response format for plain text responses.""" + + type: Literal[ResponseFormatType.text] = Field( + ResponseFormatType.text, + description="The type of the response format.", + ) + + +class JsonSchemaResponseFormat(ResponseFormat): + """Response format for JSON schema-based responses.""" + + type: Literal[ResponseFormatType.json_schema] = Field( + ResponseFormatType.json_schema, + description="The type of the response format.", + ) + json_schema: Dict[str, Any] = Field( + ..., + description="The JSON schema of the response.", + ) + + @validator("json_schema") + def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]: + """Validate that the provided schema is a valid JSON schema.""" + if not isinstance(v, dict): + raise ValueError("JSON schema must be a dictionary") + if "schema" not in v: + raise ValueError("JSON schema should include a 'schema' property") + return v + + +class JsonObjectResponseFormat(ResponseFormat): + """Response format for JSON object responses.""" + + type: Literal[ResponseFormatType.json_object] = Field( + ResponseFormatType.json_object, + description="The type of the response format.", + ) + + +# Pydantic type for validation +ResponseFormatUnion = Annotated[ + Union[TextResponseFormat | JsonSchemaResponseFormat | JsonObjectResponseFormat], + Field(discriminator="type"), +] diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index c5508f61..469ff0a2 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -1240,10 +1240,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface): and function_call.function.name == self.assistant_message_tool_name and self.assistant_message_tool_kwarg in func_args ): + # Coerce 
content to `str` in cases where it's a JSON due to `response_format` being a JSON processed_chunk = AssistantMessage( id=msg_obj.id, date=msg_obj.created_at, - content=func_args[self.assistant_message_tool_kwarg], + content=str(func_args[self.assistant_message_tool_kwarg]), name=msg_obj.name, otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 00cf8c4c..a1dcdb8e 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -364,6 +364,7 @@ class AgentManager: "base_template_id": agent_update.base_template_id, "message_buffer_autoclear": agent_update.message_buffer_autoclear, "enable_sleeptime": agent_update.enable_sleeptime, + "response_format": agent_update.response_format, } for col, val in scalar_updates.items(): if val is not None: diff --git a/tests/integration_test_send_message_schema.py b/tests/integration_test_send_message_schema.py new file mode 100644 index 00000000..57773ec8 --- /dev/null +++ b/tests/integration_test_send_message_schema.py @@ -0,0 +1,192 @@ +# TODO (cliandy): Tested in SDK +# TODO (cliandy): Comment out after merge + +# import os +# import threading +# import time + +# import pytest +# from dotenv import load_dotenv +# from letta_client import AssistantMessage, AsyncLetta, Letta, Tool + +# from letta.schemas.agent import AgentState +# from typing import List, Any, Dict + +# # ------------------------------ +# # Fixtures +# # ------------------------------ + + +# @pytest.fixture(scope="module") +# def server_url() -> str: +# """ +# Provides the URL for the Letta server. +# If the environment variable 'LETTA_SERVER_URL' is not set, this fixture +# will start the Letta server in a background thread and return the default URL. 
+# """ + +# def _run_server() -> None: +# """Starts the Letta server in a background thread.""" +# load_dotenv() # Load environment variables from .env file +# from letta.server.rest_api.app import start_server + +# start_server(debug=True) + +# # Retrieve server URL from environment, or default to localhost +# url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + +# # If no environment variable is set, start the server in a background thread +# if not os.getenv("LETTA_SERVER_URL"): +# thread = threading.Thread(target=_run_server, daemon=True) +# thread.start() +# time.sleep(5) # Allow time for the server to start + +# return url + + +# @pytest.fixture +# def client(server_url: str) -> Letta: +# """ +# Creates and returns a synchronous Letta REST client for testing. +# """ +# client_instance = Letta(base_url=server_url) +# yield client_instance + + +# @pytest.fixture +# def async_client(server_url: str) -> AsyncLetta: +# """ +# Creates and returns an asynchronous Letta REST client for testing. +# """ +# async_client_instance = AsyncLetta(base_url=server_url) +# yield async_client_instance + + +# @pytest.fixture +# def roll_dice_tool(client: Letta) -> Tool: +# """ +# Registers a simple roll dice tool with the provided client. + +# The tool simulates rolling a six-sided die but returns a fixed result. +# """ + +# def roll_dice() -> str: +# """ +# Simulates rolling a die. + +# Returns: +# str: The roll result. +# """ +# # Note: The result here is intentionally incorrect for demonstration purposes. +# return "Rolled a 10!" + +# tool = client.tools.upsert_from_function(func=roll_dice) +# yield tool + + +# @pytest.fixture +# def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: +# """ +# Creates and returns an agent state for testing with a pre-configured agent. +# The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. 
+# """ +# agent_state_instance = client.agents.create( +# name="supervisor", +# include_base_tools=True, +# tool_ids=[roll_dice_tool.id], +# model="openai/gpt-4o", +# embedding="letta/letta-free", +# tags=["supervisor"], +# include_base_tool_rules=True, + +# ) +# yield agent_state_instance + + +# # Goal is to test that when an Agent is created with a `response_format`, that the response +# # of `send_message` is in the correct format. This will be done by modifying the agent's +# # `send_message` tool so that it returns a format based on what is passed in. +# # +# # `response_format` is an optional field +# # if `response_format.type` is `text`, then the schema does not change +# # if `response_format.type` is `json_object`, then the schema is a dict +# # if `response_format.type` is `json_schema`, then the schema is a dict matching that json schema + + +# USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Send me a message."}] + +# # ------------------------------ +# # Test Cases +# # ------------------------------ + +# def test_client_send_message_text_response_format(client: "Letta", agent: "AgentState") -> None: +# """Test client send_message with response_format='json_object'.""" +# client.agents.modify(agent.id, response_format={"type": "text"}) + +# response = client.agents.messages.create_stream( +# agent_id=agent.id, +# messages=USER_MESSAGE, +# ) +# messages = list(response) +# assert isinstance(messages[-1], AssistantMessage) +# assert isinstance(messages[-1].content, str) + + +# def test_client_send_message_json_object_response_format(client: "Letta", agent: "AgentState") -> None: +# """Test client send_message with response_format='json_object'.""" +# client.agents.modify(agent.id, response_format={"type": "json_object"}) + +# response = client.agents.messages.create_stream( +# agent_id=agent.id, +# messages=USER_MESSAGE, +# ) +# messages = list(response) +# assert isinstance(messages[-1], AssistantMessage) +# assert 
isinstance(messages[-1].content, dict) + + +# def test_client_send_message_json_schema_response_format(client: "Letta", agent: "AgentState") -> None: +# """Test client send_message with response_format='json_schema' and a valid schema.""" +# client.agents.modify(agent.id, response_format={ +# "type": "json_schema", +# "json_schema": { +# "name": "reasoning_schema", +# "schema": { +# "type": "object", +# "properties": { +# "steps": { +# "type": "array", +# "items": { +# "type": "object", +# "properties": { +# "explanation": { "type": "string" }, +# "output": { "type": "string" } +# }, +# "required": ["explanation", "output"], +# "additionalProperties": False +# } +# }, +# "final_answer": { "type": "string" } +# }, +# "required": ["steps", "final_answer"], +# "additionalProperties": True +# }, +# "strict": True +# } +# }) +# response = client.agents.messages.create_stream( +# agent_id=agent.id, +# messages=USER_MESSAGE, +# ) +# messages = list(response) + +# assert isinstance(messages[-1], AssistantMessage) +# assert isinstance(messages[-1].content, dict) + + +# # def test_client_send_message_invalid_json_schema(client: "Letta", agent: "AgentState") -> None: +# # """Test client send_message with an invalid json_schema (should error or fallback).""" +# # invalid_schema: Dict[str, Any] = {"type": "object", "properties": {"foo": {"type": "unknown"}}} +# # client.agents.modify(agent.id, response_format="json_schema") +# # result: Any = client.agents.send_message(agent.id, "Test invalid schema") +# # assert result is None or "error" in str(result).lower()