feat: structured outputs for send_message (#1764)
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
"""add support for structured_outputs in agents
|
||||
|
||||
Revision ID: 28b8765bdd0a
|
||||
Revises: a3c7d62e08ca
|
||||
Create Date: 2025-04-18 11:43:47.701786
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
revision: str = "28b8765bdd0a"  # unique id of this migration
down_revision: Union[str, None] = "a3c7d62e08ca"  # parent migration in the chain
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable JSON ``response_format`` column to the ``agents`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Nullable so existing agent rows need no backfill on upgrade.
    op.add_column("agents", sa.Column("response_format", sa.JSON(), nullable=True))
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``response_format`` column from the ``agents`` table (reverses upgrade)."""
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE: any stored response formats are lost on downgrade.
    op.drop_column("agents", "response_format")
    # ### end Alembic commands ###
|
||||
@@ -17,6 +17,7 @@ from letta.constants import (
|
||||
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
|
||||
LLM_MAX_TOKENS,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
SEND_MESSAGE_TOOL_NAME,
|
||||
)
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
@@ -47,6 +48,7 @@ from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.response_format import ResponseFormatType
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
@@ -256,6 +258,28 @@ class Agent(BaseAgent):
|
||||
# Return updated messages
|
||||
return messages
|
||||
|
||||
def _runtime_override_tool_json_schema(
|
||||
self,
|
||||
functions_list: List[Dict | None],
|
||||
) -> List[Dict | None]:
|
||||
"""Override the tool JSON schema at runtime for a particular tool if conditions are met."""
|
||||
|
||||
# Currently just injects `send_message` with a `response_format` if provided to the agent.
|
||||
if self.agent_state.response_format and self.agent_state.response_format.type != ResponseFormatType.text:
|
||||
for func in functions_list:
|
||||
if func["name"] == SEND_MESSAGE_TOOL_NAME:
|
||||
if self.agent_state.response_format.type == ResponseFormatType.json_schema:
|
||||
func["parameters"]["properties"]["message"] = self.agent_state.response_format.json_schema["schema"]
|
||||
if self.agent_state.response_format.type == ResponseFormatType.json_object:
|
||||
func["parameters"]["properties"]["message"] = {
|
||||
"type": "object",
|
||||
"description": "Message contents. All unicode (including emojis) are supported.",
|
||||
"additionalProperties": True,
|
||||
"properties": {},
|
||||
}
|
||||
break
|
||||
return functions_list
|
||||
|
||||
@trace_method
|
||||
def _get_ai_reply(
|
||||
self,
|
||||
@@ -269,27 +293,26 @@ class Agent(BaseAgent):
|
||||
step_count: Optional[int] = None,
|
||||
last_function_failed: bool = False,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
) -> ChatCompletionResponse:
|
||||
) -> ChatCompletionResponse | None:
|
||||
"""Get response from LLM API with robust retry mechanism."""
|
||||
log_telemetry(self.logger, "_get_ai_reply start")
|
||||
available_tools = set([t.name for t in self.agent_state.tools])
|
||||
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
|
||||
available_tools=available_tools, last_function_response=self.last_function_response
|
||||
)
|
||||
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
|
||||
|
||||
allowed_functions = (
|
||||
agent_state_tool_jsons
|
||||
if not allowed_tool_names
|
||||
else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
|
||||
)
|
||||
# Get allowed tools or allow all if none are allowed
|
||||
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
|
||||
available_tools=available_tools, last_function_response=self.last_function_response
|
||||
) or list(available_tools)
|
||||
|
||||
# Don't allow a tool to be called if it failed last time
|
||||
if last_function_failed and self.tool_rules_solver.tool_call_history:
|
||||
allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.tool_call_history[-1]]
|
||||
if not allowed_functions:
|
||||
allowed_tool_names = [f for f in allowed_tool_names if f != self.tool_rules_solver.tool_call_history[-1]]
|
||||
if not allowed_tool_names:
|
||||
return None
|
||||
|
||||
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
|
||||
allowed_functions = self._runtime_override_tool_json_schema(allowed_functions)
|
||||
|
||||
# For the first message, force the initial tool if one is specified
|
||||
force_tool_call = None
|
||||
if (
|
||||
@@ -419,7 +442,7 @@ class Agent(BaseAgent):
|
||||
tool_call_id = response_message.tool_calls[0].id
|
||||
assert tool_call_id is not None # should be defined
|
||||
|
||||
# only necessary to add the tool_cal_id to a function call (antipattern)
|
||||
# only necessary to add the tool_call_id to a function call (antipattern)
|
||||
# response_message_dict = response_message.model_dump()
|
||||
# response_message_dict["tool_call_id"] = tool_call_id
|
||||
|
||||
@@ -514,6 +537,10 @@ class Agent(BaseAgent):
|
||||
# Failure case 3: function failed during execution
|
||||
# NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
|
||||
# this is because the function/tool role message is only created once the function/tool has executed/returned
|
||||
|
||||
# handle cases where we return a json message
|
||||
if "message" in function_args:
|
||||
function_args["message"] = str(function_args.get("message", ""))
|
||||
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
try:
|
||||
|
||||
@@ -32,6 +32,7 @@ from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.response_format import ResponseFormatUnion
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
@@ -100,6 +101,7 @@ class AbstractClient(object):
|
||||
message_ids: Optional[List[str]] = None,
|
||||
memory: Optional[Memory] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
response_format: Optional[ResponseFormatUnion] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -553,6 +555,7 @@ class RESTClient(AbstractClient):
|
||||
initial_message_sequence: Optional[List[Message]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
message_buffer_autoclear: bool = False,
|
||||
response_format: Optional[ResponseFormatUnion] = None,
|
||||
) -> AgentState:
|
||||
"""Create an agent
|
||||
|
||||
@@ -615,6 +618,7 @@ class RESTClient(AbstractClient):
|
||||
"include_base_tools": include_base_tools,
|
||||
"message_buffer_autoclear": message_buffer_autoclear,
|
||||
"include_multi_agent_tools": include_multi_agent_tools,
|
||||
"response_format": response_format,
|
||||
}
|
||||
|
||||
# Only add name if it's not None
|
||||
@@ -653,6 +657,7 @@ class RESTClient(AbstractClient):
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
message_ids: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
response_format: Optional[ResponseFormatUnion] = None,
|
||||
) -> AgentState:
|
||||
"""
|
||||
Update an existing agent
|
||||
@@ -682,6 +687,7 @@ class RESTClient(AbstractClient):
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
message_ids=message_ids,
|
||||
response_format=response_format,
|
||||
)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
@@ -2425,6 +2431,7 @@ class LocalClient(AbstractClient):
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
message_ids: Optional[List[str]] = None,
|
||||
response_format: Optional[ResponseFormatUnion] = None,
|
||||
):
|
||||
"""
|
||||
Update an existing agent
|
||||
@@ -2458,6 +2465,7 @@ class LocalClient(AbstractClient):
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
message_ids=message_ids,
|
||||
response_format=response_format,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
@@ -47,13 +47,14 @@ DEFAULT_PERSONA = "sam_pov"
|
||||
DEFAULT_HUMAN = "basic"
|
||||
DEFAULT_PRESET = "memgpt_chat"
|
||||
|
||||
SEND_MESSAGE_TOOL_NAME = "send_message"
|
||||
# Base tools that cannot be edited, as they access agent state directly
|
||||
# Note that we don't include "conversation_search_date" for now
|
||||
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
|
||||
BASE_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_insert", "archival_memory_search"]
|
||||
# Base memory tools CAN be edited, and are added by default by the server
|
||||
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
# Base tools if the memgpt agent has enable_sleeptime on
|
||||
BASE_SLEEPTIME_CHAT_TOOLS = ["send_message", "conversation_search", "archival_memory_search"]
|
||||
BASE_SLEEPTIME_CHAT_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_search"]
|
||||
# Base memory tools for sleeptime agent
|
||||
BASE_SLEEPTIME_TOOLS = [
|
||||
"memory_replace",
|
||||
@@ -72,7 +73,7 @@ LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_S
|
||||
# The name of the tool used to send message to the user
|
||||
# May not be relevant in cases where the agent has multiple ways to message the user (send_imessage, send_discord_message, ...)
|
||||
# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent)
|
||||
DEFAULT_MESSAGE_TOOL = "send_message"
|
||||
DEFAULT_MESSAGE_TOOL = SEND_MESSAGE_TOOL_NAME
|
||||
DEFAULT_MESSAGE_TOOL_KWARG = "message"
|
||||
|
||||
PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"
|
||||
|
||||
@@ -22,6 +22,13 @@ from letta.schemas.letta_message_content import (
|
||||
)
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import ToolReturn
|
||||
from letta.schemas.response_format import (
|
||||
JsonObjectResponseFormat,
|
||||
JsonSchemaResponseFormat,
|
||||
ResponseFormatType,
|
||||
ResponseFormatUnion,
|
||||
TextResponseFormat,
|
||||
)
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
@@ -371,3 +378,25 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat
|
||||
return None
|
||||
|
||||
return AgentStepState(**data)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Response Format Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_response_format(response_format: Optional[ResponseFormatUnion]) -> Optional[Dict[str, Any]]:
    """Serialize a response-format model into a JSON-safe dict; None passes through unchanged."""
    return response_format.model_dump(mode="json") if response_format else None
|
||||
|
||||
|
||||
def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormatUnion]:
    """Rebuild the concrete response-format model from its stored dict form.

    Returns None for empty input or an unrecognized ``type`` value.
    """
    if not data:
        return None
    # Dispatch on the discriminator; str-enum keys match the plain strings
    # produced by JSON round-tripping.
    model_by_type = {
        ResponseFormatType.text: TextResponseFormat,
        ResponseFormatType.json_schema: JsonSchemaResponseFormat,
        ResponseFormatType.json_object: JsonObjectResponseFormat,
    }
    model_cls = model_by_type.get(data["type"])
    return model_cls(**data) if model_cls is not None else None
|
||||
|
||||
@@ -5,7 +5,7 @@ from sqlalchemy import JSON, Boolean, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.organization import Organization
|
||||
@@ -15,6 +15,7 @@ from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.response_format import ResponseFormatUnion
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -48,6 +49,11 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
# This is dangerously flexible with the JSON type
|
||||
message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.")
|
||||
|
||||
# Response Format
|
||||
response_format: Mapped[Optional[ResponseFormatUnion]] = mapped_column(
|
||||
ResponseFormatColumn, nullable=True, doc="The response format for the agent."
|
||||
)
|
||||
|
||||
# Metadata and configs
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
|
||||
llm_config: Mapped[Optional[LLMConfig]] = mapped_column(
|
||||
@@ -168,6 +174,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
"multi_agent_group": None,
|
||||
"tool_exec_environment_variables": [],
|
||||
"enable_sleeptime": None,
|
||||
"response_format": self.response_format,
|
||||
}
|
||||
|
||||
# Optional fields: only included if requested
|
||||
|
||||
@@ -9,6 +9,7 @@ from letta.helpers.converters import (
|
||||
deserialize_llm_config,
|
||||
deserialize_message_content,
|
||||
deserialize_poll_batch_response,
|
||||
deserialize_response_format,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_returns,
|
||||
deserialize_tool_rules,
|
||||
@@ -20,6 +21,7 @@ from letta.helpers.converters import (
|
||||
serialize_llm_config,
|
||||
serialize_message_content,
|
||||
serialize_poll_batch_response,
|
||||
serialize_response_format,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_returns,
|
||||
serialize_tool_rules,
|
||||
@@ -168,3 +170,16 @@ class AgentStepStateColumn(TypeDecorator):
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return deserialize_agent_step_state(value)
|
||||
|
||||
|
||||
class ResponseFormatColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing an agent's response format as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Pydantic model -> JSON-safe dict on the way into the database.
        return serialize_response_format(value)

    def process_result_value(self, value, dialect):
        # Stored dict -> concrete response-format model on the way out.
        return deserialize_response_format(value)
|
||||
|
||||
@@ -14,6 +14,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.response_format import ResponseFormatUnion
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
@@ -66,6 +67,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
# llm information
|
||||
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
|
||||
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(
|
||||
None, description="The response format used by the agent when returning from `send_message`."
|
||||
)
|
||||
|
||||
# This is an object representing the in-process state of a running `Agent`
|
||||
# Field in this object can be theoretically edited by tools, and will be persisted by the ORM
|
||||
@@ -180,6 +184,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
||||
)
|
||||
enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
@@ -259,6 +264,7 @@ class UpdateAgent(BaseModel):
|
||||
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
||||
)
|
||||
enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # Ignores extra fields
|
||||
|
||||
78
letta/schemas/response_format.py
Normal file
78
letta/schemas/response_format.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class ResponseFormatType(str, Enum):
    """Enum defining the possible response format types."""

    text = "text"  # free-form text; leaves tool schemas unchanged
    json_schema = "json_schema"  # output constrained by a caller-supplied JSON schema
    json_object = "json_object"  # arbitrary JSON object with no fixed schema
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
    """Base class for all response formats."""

    # Discriminator field; concrete subclasses narrow this to a Literal member.
    type: ResponseFormatType = Field(
        ...,
        description="The type of the response format.",
        # `example` surfaces in the generated OpenAPI/JSON schema.
        # NOTE(review): `example=` is deprecated in pydantic v2 in favor of
        # json_schema_extra={"example": ...} — confirm pydantic version before changing.
        example=ResponseFormatType.text,
    )
||||
|
||||
|
||||
# ---------------------
|
||||
# Response Format Types
|
||||
# ---------------------
|
||||
|
||||
# Plain-dict alias for the JSON-serialized form stored in the database.
# NOTE(review): appears unused within this module — confirm external users before removing.
ResponseFormatDict = Dict[str, Any]
|
||||
class TextResponseFormat(ResponseFormat):
    """Response format for plain text responses."""

    # Fixed discriminator value for this variant of ResponseFormatUnion.
    type: Literal[ResponseFormatType.text] = Field(
        ResponseFormatType.text,
        description="The type of the response format.",
    )
||||
|
||||
|
||||
class JsonSchemaResponseFormat(ResponseFormat):
    """Response format for JSON schema-based responses."""

    # Fixed discriminator value for this variant of ResponseFormatUnion.
    type: Literal[ResponseFormatType.json_schema] = Field(
        ResponseFormatType.json_schema,
        description="The type of the response format.",
    )
    # OpenAI-style wrapper dict, e.g. {"name": ..., "schema": {...}, "strict": ...};
    # downstream schema-override code reads json_schema["schema"].
    json_schema: Dict[str, Any] = Field(
        ...,
        description="The JSON schema of the response.",
    )

    @validator("json_schema")
    def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that the provided wrapper carries an inner ``schema`` body.

        Raises:
            ValueError: if the value is not a dict or lacks the ``schema`` key.
        """
        if not isinstance(v, dict):
            raise ValueError("JSON schema must be a dictionary")
        if "schema" not in v:
            # Message previously claimed a "$schema" property was required, which
            # contradicted this check: consumers need the OpenAI-style "schema" key.
            raise ValueError('JSON schema should include a "schema" property')
        return v
||||
|
||||
|
||||
class JsonObjectResponseFormat(ResponseFormat):
    """Response format for JSON object responses."""

    # Fixed discriminator value for this variant of ResponseFormatUnion.
    # No extra fields: any JSON object is accepted, no schema is enforced.
    type: Literal[ResponseFormatType.json_object] = Field(
        ResponseFormatType.json_object,
        description="The type of the response format.",
    )
||||
|
||||
|
||||
# Pydantic type for validation: discriminated union keyed on `type`.
# Plain PEP 604 union; the previous `Union[A | B | C]` wrapper was redundant
# (a Union of a union flattens to the same type).
ResponseFormatUnion = Annotated[
    TextResponseFormat | JsonSchemaResponseFormat | JsonObjectResponseFormat,
    Field(discriminator="type"),
]
|
||||
@@ -1240,10 +1240,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
and function_call.function.name == self.assistant_message_tool_name
|
||||
and self.assistant_message_tool_kwarg in func_args
|
||||
):
|
||||
# Coerce content to `str` in cases where it's a JSON due to `response_format` being a JSON
|
||||
processed_chunk = AssistantMessage(
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
content=func_args[self.assistant_message_tool_kwarg],
|
||||
content=str(func_args[self.assistant_message_tool_kwarg]),
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
|
||||
@@ -364,6 +364,7 @@ class AgentManager:
|
||||
"base_template_id": agent_update.base_template_id,
|
||||
"message_buffer_autoclear": agent_update.message_buffer_autoclear,
|
||||
"enable_sleeptime": agent_update.enable_sleeptime,
|
||||
"response_format": agent_update.response_format,
|
||||
}
|
||||
for col, val in scalar_updates.items():
|
||||
if val is not None:
|
||||
|
||||
192
tests/integration_test_send_message_schema.py
Normal file
192
tests/integration_test_send_message_schema.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# TODO (cliandy): Tested in SDK
|
||||
# TODO (cliandy): Comment out after merge
|
||||
|
||||
# import os
|
||||
# import threading
|
||||
# import time
|
||||
|
||||
# import pytest
|
||||
# from dotenv import load_dotenv
|
||||
# from letta_client import AssistantMessage, AsyncLetta, Letta, Tool
|
||||
|
||||
# from letta.schemas.agent import AgentState
|
||||
# from typing import List, Any, Dict
|
||||
|
||||
# # ------------------------------
|
||||
# # Fixtures
|
||||
# # ------------------------------
|
||||
|
||||
|
||||
# @pytest.fixture(scope="module")
|
||||
# def server_url() -> str:
|
||||
# """
|
||||
# Provides the URL for the Letta server.
|
||||
# If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
|
||||
# will start the Letta server in a background thread and return the default URL.
|
||||
# """
|
||||
|
||||
# def _run_server() -> None:
|
||||
# """Starts the Letta server in a background thread."""
|
||||
# load_dotenv() # Load environment variables from .env file
|
||||
# from letta.server.rest_api.app import start_server
|
||||
|
||||
# start_server(debug=True)
|
||||
|
||||
# # Retrieve server URL from environment, or default to localhost
|
||||
# url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
# # If no environment variable is set, start the server in a background thread
|
||||
# if not os.getenv("LETTA_SERVER_URL"):
|
||||
# thread = threading.Thread(target=_run_server, daemon=True)
|
||||
# thread.start()
|
||||
# time.sleep(5) # Allow time for the server to start
|
||||
|
||||
# return url
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def client(server_url: str) -> Letta:
|
||||
# """
|
||||
# Creates and returns a synchronous Letta REST client for testing.
|
||||
# """
|
||||
# client_instance = Letta(base_url=server_url)
|
||||
# yield client_instance
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def async_client(server_url: str) -> AsyncLetta:
|
||||
# """
|
||||
# Creates and returns an asynchronous Letta REST client for testing.
|
||||
# """
|
||||
# async_client_instance = AsyncLetta(base_url=server_url)
|
||||
# yield async_client_instance
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def roll_dice_tool(client: Letta) -> Tool:
|
||||
# """
|
||||
# Registers a simple roll dice tool with the provided client.
|
||||
|
||||
# The tool simulates rolling a six-sided die but returns a fixed result.
|
||||
# """
|
||||
|
||||
# def roll_dice() -> str:
|
||||
# """
|
||||
# Simulates rolling a die.
|
||||
|
||||
# Returns:
|
||||
# str: The roll result.
|
||||
# """
|
||||
# # Note: The result here is intentionally incorrect for demonstration purposes.
|
||||
# return "Rolled a 10!"
|
||||
|
||||
# tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
# yield tool
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
# """
|
||||
# Creates and returns an agent state for testing with a pre-configured agent.
|
||||
# The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
|
||||
# """
|
||||
# agent_state_instance = client.agents.create(
|
||||
# name="supervisor",
|
||||
# include_base_tools=True,
|
||||
# tool_ids=[roll_dice_tool.id],
|
||||
# model="openai/gpt-4o",
|
||||
# embedding="letta/letta-free",
|
||||
# tags=["supervisor"],
|
||||
# include_base_tool_rules=True,
|
||||
|
||||
# )
|
||||
# yield agent_state_instance
|
||||
|
||||
|
||||
# # Goal is to test that when an Agent is created with a `response_format`, that the response
|
||||
# # of `send_message` is in the correct format. This will be done by modifying the agent's
|
||||
# # `send_message` tool so that it returns a format based on what is passed in.
|
||||
# #
|
||||
# # `response_format` is an optional field
|
||||
# # if `response_format.type` is `text`, then the schema does not change
|
||||
# # if `response_format.type` is `json_object`, then the schema is a dict
|
||||
# # if `response_format.type` is `json_schema`, then the schema is a dict matching that json schema
|
||||
|
||||
|
||||
# USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Send me a message."}]
|
||||
|
||||
# # ------------------------------
|
||||
# # Test Cases
|
||||
# # ------------------------------
|
||||
|
||||
# def test_client_send_message_text_response_format(client: "Letta", agent: "AgentState") -> None:
|
||||
# """Test client send_message with response_format='json_object'."""
|
||||
# client.agents.modify(agent.id, response_format={"type": "text"})
|
||||
|
||||
# response = client.agents.messages.create_stream(
|
||||
# agent_id=agent.id,
|
||||
# messages=USER_MESSAGE,
|
||||
# )
|
||||
# messages = list(response)
|
||||
# assert isinstance(messages[-1], AssistantMessage)
|
||||
# assert isinstance(messages[-1].content, str)
|
||||
|
||||
|
||||
# def test_client_send_message_json_object_response_format(client: "Letta", agent: "AgentState") -> None:
|
||||
# """Test client send_message with response_format='json_object'."""
|
||||
# client.agents.modify(agent.id, response_format={"type": "json_object"})
|
||||
|
||||
# response = client.agents.messages.create_stream(
|
||||
# agent_id=agent.id,
|
||||
# messages=USER_MESSAGE,
|
||||
# )
|
||||
# messages = list(response)
|
||||
# assert isinstance(messages[-1], AssistantMessage)
|
||||
# assert isinstance(messages[-1].content, dict)
|
||||
|
||||
|
||||
# def test_client_send_message_json_schema_response_format(client: "Letta", agent: "AgentState") -> None:
|
||||
# """Test client send_message with response_format='json_schema' and a valid schema."""
|
||||
# client.agents.modify(agent.id, response_format={
|
||||
# "type": "json_schema",
|
||||
# "json_schema": {
|
||||
# "name": "reasoning_schema",
|
||||
# "schema": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "steps": {
|
||||
# "type": "array",
|
||||
# "items": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "explanation": { "type": "string" },
|
||||
# "output": { "type": "string" }
|
||||
# },
|
||||
# "required": ["explanation", "output"],
|
||||
# "additionalProperties": False
|
||||
# }
|
||||
# },
|
||||
# "final_answer": { "type": "string" }
|
||||
# },
|
||||
# "required": ["steps", "final_answer"],
|
||||
# "additionalProperties": True
|
||||
# },
|
||||
# "strict": True
|
||||
# }
|
||||
# })
|
||||
# response = client.agents.messages.create_stream(
|
||||
# agent_id=agent.id,
|
||||
# messages=USER_MESSAGE,
|
||||
# )
|
||||
# messages = list(response)
|
||||
|
||||
# assert isinstance(messages[-1], AssistantMessage)
|
||||
# assert isinstance(messages[-1].content, dict)
|
||||
|
||||
|
||||
# # def test_client_send_message_invalid_json_schema(client: "Letta", agent: "AgentState") -> None:
|
||||
# # """Test client send_message with an invalid json_schema (should error or fallback)."""
|
||||
# # invalid_schema: Dict[str, Any] = {"type": "object", "properties": {"foo": {"type": "unknown"}}}
|
||||
# # client.agents.modify(agent.id, response_format="json_schema")
|
||||
# # result: Any = client.agents.send_message(agent.id, "Test invalid schema")
|
||||
# # assert result is None or "error" in str(result).lower()
|
||||
Reference in New Issue
Block a user