feat: rename function to tool in sdk (#2288)

Co-authored-by: Caren Thomas <caren@caren-mac.local>
This commit is contained in:
cthomas
2024-12-19 12:12:58 -08:00
committed by GitHub
parent 5f2ba44e93
commit 7d5be32a59
16 changed files with 202 additions and 164 deletions

View File

@@ -75,6 +75,10 @@ def nb_print(messages):
return_data = json.loads(msg.function_return)
if "message" in return_data and return_data["message"] == "None":
continue
if msg.message_type == "tool_return_message":
return_data = json.loads(msg.tool_return)
if "message" in return_data and return_data["message"] == "None":
continue
title = msg.message_type.replace("_", " ").upper()
html_output += f"""
@@ -94,11 +98,17 @@ def get_formatted_content(msg):
elif msg.message_type == "function_call":
args = format_json(msg.function_call.arguments)
return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
elif msg.message_type == "tool_call_message":
args = format_json(msg.tool_call.arguments)
return f'<div class="content"><span class="function-name">{html.escape(msg.tool_call.name)}</span>({args})</div>'
elif msg.message_type == "function_return":
return_value = format_json(msg.function_return)
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
return f'<div class="content">{return_value}</div>'
elif msg.message_type == "tool_return_message":
return_value = format_json(msg.tool_return)
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
return f'<div class="content">{return_value}</div>'
elif msg.message_type == "user_message":
if is_json(msg.message):
return f'<div class="content">{format_json(msg.message)}</div>'

View File

@@ -2,7 +2,7 @@ import os
import uuid
from letta import create_client
from letta.schemas.letta_message import FunctionCallMessage
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import (
assert_invoked_send_message_with_keyword,
@@ -116,9 +116,9 @@ def main():
# 6. Here, we thoroughly check the correctness of the response
tool_names += ["send_message"] # Add send message because we expect this to be called at the end
for m in response.messages:
if isinstance(m, FunctionCallMessage):
if isinstance(m, ToolCallMessage):
# Check that it's equal to the first one
assert m.function_call.name == tool_names[0]
assert m.tool_call.name == tool_names[0]
# Pop out first one
tool_names = tool_names[1:]

View File

@@ -8,8 +8,8 @@ from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
FunctionCallMessage,
FunctionReturn,
ToolCallMessage,
ToolReturnMessage,
InternalMonologue,
)
from letta.schemas.letta_response import LettaStreamingResponse
@@ -55,10 +55,10 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
chunk_data = json.loads(sse.data)
if "internal_monologue" in chunk_data:
yield InternalMonologue(**chunk_data)
elif "function_call" in chunk_data:
yield FunctionCallMessage(**chunk_data)
elif "function_return" in chunk_data:
yield FunctionReturn(**chunk_data)
elif "tool_call" in chunk_data:
yield ToolCallMessage(**chunk_data)
elif "tool_return" in chunk_data:
yield ToolReturnMessage(**chunk_data)
elif "usage" in chunk_data:
yield LettaUsageStatistics(**chunk_data["usage"])
else:

View File

@@ -131,16 +131,16 @@ class LettaMessageError(LettaError):
return f"{error_msg}\n\n{message_json}"
class MissingFunctionCallError(LettaMessageError):
"""Error raised when a message is missing a function call."""
class MissingToolCallError(LettaMessageError):
"""Error raised when a message is missing a tool call."""
default_error_message = "The message is missing a function call."
default_error_message = "The message is missing a tool call."
class InvalidFunctionCallError(LettaMessageError):
"""Error raised when a message uses an invalid function call."""
class InvalidToolCallError(LettaMessageError):
"""Error raised when a message uses an invalid tool call."""
default_error_message = "The message uses an invalid function call or has improper usage of a function call."
default_error_message = "The message uses an invalid tool call or has improper usage of a tool call."
class MissingInnerMonologueError(LettaMessageError):

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, field_serializer, field_validator
class LettaMessage(BaseModel):
"""
Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, function calls, and function returns in a simplified format that does not include additional information other than the content and timestamp.
Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp.
Attributes:
id (str): The ID of the message
@@ -74,18 +74,18 @@ class InternalMonologue(LettaMessage):
internal_monologue: str
class FunctionCall(BaseModel):
class ToolCall(BaseModel):
name: str
arguments: str
function_call_id: str
tool_call_id: str
class FunctionCallDelta(BaseModel):
class ToolCallDelta(BaseModel):
name: Optional[str]
arguments: Optional[str]
function_call_id: Optional[str]
tool_call_id: Optional[str]
# NOTE: this is a workaround to exclude None values from the JSON dump,
# since the OpenAI style of returning chunks doesn't include keys with null values
@@ -97,50 +97,84 @@ class FunctionCallDelta(BaseModel):
return json.dumps(self.model_dump(exclude_none=True), *args, **kwargs)
class FunctionCallMessage(LettaMessage):
class ToolCallMessage(LettaMessage):
"""
A message representing a request to call a function (generated by the LLM to trigger function execution).
A message representing a request to call a tool (generated by the LLM to trigger tool execution).
Attributes:
function_call (Union[FunctionCall, FunctionCallDelta]): The function call
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
"""
message_type: Literal["function_call"] = "function_call"
function_call: Union[FunctionCall, FunctionCallDelta]
message_type: Literal["tool_call_message"] = "tool_call_message"
tool_call: Union[ToolCall, ToolCallDelta]
# NOTE: this is required for the FunctionCallDelta exclude_none to work correctly
# NOTE: this is required for the ToolCallDelta exclude_none to work correctly
def model_dump(self, *args, **kwargs):
kwargs["exclude_none"] = True
data = super().model_dump(*args, **kwargs)
if isinstance(data["function_call"], dict):
data["function_call"] = {k: v for k, v in data["function_call"].items() if v is not None}
if isinstance(data["tool_call"], dict):
data["tool_call"] = {k: v for k, v in data["tool_call"].items() if v is not None}
return data
class Config:
json_encoders = {
FunctionCallDelta: lambda v: v.model_dump(exclude_none=True),
FunctionCall: lambda v: v.model_dump(exclude_none=True),
ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
ToolCall: lambda v: v.model_dump(exclude_none=True),
}
# NOTE: this is required to cast dicts into FunctionCallMessage objects
# NOTE: this is required to cast dicts into ToolCallMessage objects
# Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None
# (instead of properly casting to FunctionCallDelta instead of FunctionCall)
@field_validator("function_call", mode="before")
# (instead of properly casting to ToolCallDelta instead of ToolCall)
@field_validator("tool_call", mode="before")
@classmethod
def validate_function_call(cls, v):
def validate_tool_call(cls, v):
if isinstance(v, dict):
if "name" in v and "arguments" in v and "function_call_id" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"], function_call_id=v["function_call_id"])
elif "name" in v or "arguments" in v or "function_call_id" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"), function_call_id=v.get("function_call_id"))
if "name" in v and "arguments" in v and "tool_call_id" in v:
return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
elif "name" in v or "arguments" in v or "tool_call_id" in v:
return ToolCallDelta(name=v.get("name"), arguments=v.get("arguments"), tool_call_id=v.get("tool_call_id"))
else:
raise ValueError("function_call must contain either 'name' or 'arguments'")
raise ValueError("tool_call must contain either 'name' or 'arguments'")
return v
class FunctionReturn(LettaMessage):
class ToolReturnMessage(LettaMessage):
"""
A message representing the return value of a tool call (generated by Letta executing the requested tool).
Attributes:
tool_return (str): The return value of the tool
status (Literal["success", "error"]): The status of the tool call
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
tool_call_id (str): A unique identifier for the tool call that generated this message
stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the tool invocation
stderr (Optional[List[str]]): Captured stderr from the tool invocation
"""
message_type: Literal["tool_return_message"] = "tool_return_message"
tool_return: str
status: Literal["success", "error"]
tool_call_id: str
stdout: Optional[List[str]] = None
stderr: Optional[List[str]] = None
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
class AssistantMessage(LettaMessage):
message_type: Literal["assistant_message"] = "assistant_message"
assistant_message: str
class LegacyFunctionCallMessage(LettaMessage):
function_call: str
class LegacyFunctionReturn(LettaMessage):
"""
A message representing the return value of a function call (generated by Letta executing the requested function).
@@ -162,22 +196,10 @@ class FunctionReturn(LettaMessage):
stderr: Optional[List[str]] = None
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
class AssistantMessage(LettaMessage):
message_type: Literal["assistant_message"] = "assistant_message"
assistant_message: str
class LegacyFunctionCallMessage(LettaMessage):
function_call: str
LegacyLettaMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn]
LegacyLettaMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
LettaMessageUnion = Annotated[
Union[SystemMessage, UserMessage, InternalMonologue, FunctionCallMessage, FunctionReturn, AssistantMessage],
Union[SystemMessage, UserMessage, InternalMonologue, ToolCallMessage, ToolReturnMessage, AssistantMessage],
Field(discriminator="message_type"),
]

View File

@@ -43,11 +43,17 @@ class LettaResponse(BaseModel):
elif msg.message_type == "function_call":
args = format_json(msg.function_call.arguments)
return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
elif msg.message_type == "tool_call_message":
args = format_json(msg.tool_call.arguments)
return f'<div class="content"><span class="function-name">{html.escape(msg.tool_call.name)}</span>({args})</div>'
elif msg.message_type == "function_return":
return_value = format_json(msg.function_return)
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
return f'<div class="content">{return_value}</div>'
elif msg.message_type == "tool_return_message":
return_value = format_json(msg.tool_return)
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
return f'<div class="content">{return_value}</div>'
elif msg.message_type == "user_message":
if is_json(msg.message):
return f'<div class="content">{format_json(msg.message)}</div>'

View File

@@ -16,9 +16,9 @@ from letta.schemas.enums import MessageRole
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCall,
FunctionCallMessage,
FunctionReturn,
ToolCall as LettaToolCall,
ToolCallMessage,
ToolReturnMessage,
InternalMonologue,
LettaMessage,
SystemMessage,
@@ -172,18 +172,18 @@ class Message(BaseMessage):
)
else:
messages.append(
FunctionCallMessage(
ToolCallMessage(
id=self.id,
date=self.created_at,
function_call=FunctionCall(
tool_call=LettaToolCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
function_call_id=tool_call.id,
tool_call_id=tool_call.id,
),
)
)
elif self.role == MessageRole.tool:
# This is type FunctionReturn
# This is type ToolReturnMessage
# Try to interpret the function return, recall that this is how we packaged:
# def package_function_response(was_success, response_string, timestamp=None):
# formatted_time = get_local_time() if timestamp is None else timestamp
@@ -208,12 +208,12 @@ class Message(BaseMessage):
messages.append(
# TODO make sure this is what the API returns
# function_return may not match exactly...
FunctionReturn(
ToolReturnMessage(
id=self.id,
date=self.created_at,
function_return=self.text,
tool_return=self.text,
status=status_enum,
function_call_id=self.tool_call_id,
tool_call_id=self.tool_call_id,
)
)
elif self.role == MessageRole.user:

View File

@@ -12,10 +12,10 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCall,
FunctionCallDelta,
FunctionCallMessage,
FunctionReturn,
ToolCall,
ToolCallDelta,
ToolCallMessage,
ToolReturnMessage,
InternalMonologue,
LegacyFunctionCallMessage,
LegacyLettaMessage,
@@ -411,7 +411,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def _process_chunk_to_letta_style(
self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime
) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]:
) -> Optional[Union[InternalMonologue, ToolCallMessage, AssistantMessage]]:
"""
Example data from non-streaming response looks like:
@@ -442,7 +442,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if self.inner_thoughts_in_kwargs:
raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported")
# If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode
# If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard ToolCallMessage passthrough mode
# Track the function name while streaming
# If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
@@ -474,7 +474,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
assistant_message=cleaned_func_args,
)
# otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage
# otherwise we just do a regular passthrough of a ToolCallDelta via a ToolCallMessage
else:
tool_call_delta = {}
if tool_call.id:
@@ -485,13 +485,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if tool_call.function.name:
tool_call_delta["name"] = tool_call.function.name
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
tool_call_id=tool_call_delta.get("id"),
),
)
@@ -531,7 +531,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
else:
self.function_args_buffer += updates_main_json
# If we have main_json, we should output a FunctionCallMessage
# If we have main_json, we should output a ToolCallMessage
elif updates_main_json:
# If there's something in the function_name buffer, we should release it first
@@ -539,13 +539,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# however the frontend may expect name first, then args, so to be
# safe we'll output name first in a separate chunk
if self.function_name_buffer:
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=None,
function_call_id=self.function_id_buffer,
tool_call_id=self.function_id_buffer,
),
)
# Clear the buffer
@@ -561,20 +561,20 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.function_args_buffer += updates_main_json
# If there was nothing in the name buffer, we can proceed to
# output the arguments chunk as a FunctionCallMessage
# output the arguments chunk as a ToolCallMessage
else:
# There may be a buffer from a previous chunk, for example
# if the previous chunk had arguments but we needed to flush name
if self.function_args_buffer:
# In this case, we should release the buffer + new data at once
combined_chunk = self.function_args_buffer + updates_main_json
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=None,
arguments=combined_chunk,
function_call_id=self.function_id_buffer,
tool_call_id=self.function_id_buffer,
),
)
# clear buffer
@@ -582,13 +582,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.function_id_buffer = None
else:
# If there's no buffer to clear, just output a new chunk with new data
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
function_call_id=self.function_id_buffer,
tool_call_id=self.function_id_buffer,
),
)
self.function_id_buffer = None
@@ -608,10 +608,10 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# # if tool_call.function.name:
# # tool_call_delta["name"] = tool_call.function.name
# processed_chunk = FunctionCallMessage(
# processed_chunk = ToolCallMessage(
# id=message_id,
# date=message_date,
# function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# tool_call=ToolCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# )
else:
@@ -642,10 +642,10 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# if tool_call.function.name:
# tool_call_delta["name"] = tool_call.function.name
# processed_chunk = FunctionCallMessage(
# processed_chunk = ToolCallMessage(
# id=message_id,
# date=message_date,
# function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# tool_call=ToolCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# )
# elif False and self.inner_thoughts_in_kwargs and tool_call.function:
@@ -682,13 +682,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# If it does match, start processing the value (stringified-JSON string
# And with each new chunk, output it as a chunk of type InternalMonologue
# If the key doesn't match, then flush the buffer as a single FunctionCallMessage chunk
# If the key doesn't match, then flush the buffer as a single ToolCallMessage chunk
# If we're reading a value
# If we're reading the inner thoughts value, we output chunks of type InternalMonologue
# Otherwise, do simple chunks of FunctionCallMessage
# Otherwise, do simple chunks of ToolCallMessage
else:
@@ -701,13 +701,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if tool_call.function.name:
tool_call_delta["name"] = tool_call.function.name
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
tool_call_id=tool_call_delta.get("id"),
),
)
@@ -911,13 +911,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
assistant_message=func_args[self.assistant_message_tool_kwarg],
)
else:
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_call=FunctionCall(
tool_call=ToolCall(
name=function_call.function.name,
arguments=function_call.function.arguments,
function_call_id=function_call.id,
tool_call_id=function_call.id,
),
)
@@ -942,24 +942,24 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
msg = msg.replace("Success: ", "")
# new_message = {"function_return": msg, "status": "success"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
new_message = ToolReturnMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
tool_return=msg,
status="success",
function_call_id=msg_obj.tool_call_id,
tool_call_id=msg_obj.tool_call_id,
)
elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
# new_message = {"function_return": msg, "status": "error"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
new_message = ToolReturnMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
tool_return=msg,
status="error",
function_call_id=msg_obj.tool_call_id,
tool_call_id=msg_obj.tool_call_id,
)
else:

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import FunctionCall, LettaMessage
from letta.schemas.letta_message import ToolCall, LettaMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_response import (
ChatCompletionResponse,
@@ -94,7 +94,7 @@ async def create_chat_completion(
created_at = None
for letta_msg in response_messages.messages:
assert isinstance(letta_msg, LettaMessage)
if isinstance(letta_msg, FunctionCall):
if isinstance(letta_msg, ToolCall):
if letta_msg.name and letta_msg.name == "send_message":
try:
letta_function_call_args = json.loads(letta_msg.arguments)

View File

@@ -7,7 +7,7 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.errors import LettaToolCreateError
from letta.orm.errors import UniqueConstraintViolationError
from letta.schemas.letta_message import FunctionReturn
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
from letta.schemas.user import User
from letta.server.rest_api.utils import get_letta_server
@@ -163,7 +163,7 @@ def upsert_base_tools(
return server.tool_manager.upsert_base_tools(actor=actor)
@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source")
@router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source")
def run_tool_from_source(
server: SyncServer = Depends(get_letta_server),
request: ToolRunFromSource = Body(...),

View File

@@ -47,7 +47,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import FunctionReturn, LettaMessage
from letta.schemas.letta_message import ToolReturnMessage, LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import (
ArchivalMemorySummary,
@@ -1350,7 +1350,7 @@ class SyncServer(Server):
tool_source: str,
tool_source_type: Optional[str] = None,
tool_name: Optional[str] = None,
) -> FunctionReturn:
) -> ToolReturnMessage:
"""Run a tool from source code"""
try:
@@ -1374,24 +1374,24 @@ class SyncServer(Server):
# Next, attempt to run the tool with the sandbox
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state)
return FunctionReturn(
return ToolReturnMessage(
id="null",
function_call_id="null",
tool_call_id="null",
date=get_utc_time(),
status=sandbox_run_result.status,
function_return=str(sandbox_run_result.func_return),
tool_return=str(sandbox_run_result.func_return),
stdout=sandbox_run_result.stdout,
stderr=sandbox_run_result.stderr,
)
except Exception as e:
func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e))
return FunctionReturn(
return ToolReturnMessage(
id="null",
function_call_id="null",
tool_call_id="null",
date=get_utc_time(),
status="error",
function_return=func_return,
tool_return=func_return,
stdout=[],
stderr=[traceback.format_exc()],
)

View File

@@ -15,9 +15,9 @@ from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.embeddings import embedding_model
from letta.errors import (
InvalidFunctionCallError,
InvalidToolCallError,
InvalidInnerMonologueError,
MissingFunctionCallError,
MissingToolCallError,
MissingInnerMonologueError,
)
from letta.llm_api.llm_api_tools import create
@@ -25,7 +25,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import (
FunctionCallMessage,
ToolCallMessage,
InternalMonologue,
LettaMessage,
)
@@ -377,23 +377,23 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
# Find first instance of send_message
target_message = None
for message in messages:
if isinstance(message, FunctionCallMessage) and message.function_call.name == "send_message":
if isinstance(message, ToolCallMessage) and message.tool_call.name == "send_message":
target_message = message
break
# No messages found with `send_messages`
if target_message is None:
raise MissingFunctionCallError(messages=messages, explanation="Missing `send_message` function call")
raise MissingToolCallError(messages=messages, explanation="Missing `send_message` function call")
send_message_function_call = target_message.function_call
send_message_function_call = target_message.tool_call
try:
arguments = json.loads(send_message_function_call.arguments)
except:
raise InvalidFunctionCallError(messages=[target_message], explanation="Function call arguments could not be loaded into JSON")
raise InvalidToolCallError(messages=[target_message], explanation="Function call arguments could not be loaded into JSON")
# Message field not in send_message
if "message" not in arguments:
raise InvalidFunctionCallError(
raise InvalidToolCallError(
messages=[target_message], explanation=f"send_message function call does not have required field `message`"
)
@@ -403,16 +403,16 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
arguments["message"] = arguments["message"].lower()
if not keyword in arguments["message"]:
raise InvalidFunctionCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")
raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")
def assert_invoked_function_call(messages: List[LettaMessage], function_name: str) -> None:
for message in messages:
if isinstance(message, FunctionCallMessage) and message.function_call.name == function_name:
if isinstance(message, ToolCallMessage) and message.tool_call.name == function_name:
# Found it, do nothing
return
raise MissingFunctionCallError(
raise MissingToolCallError(
messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}"
)
@@ -446,7 +446,7 @@ def assert_contains_valid_function_call(
if (hasattr(message, "function_call") and message.function_call is not None) and (
hasattr(message, "tool_calls") and message.tool_calls is not None
):
raise InvalidFunctionCallError(messages=[message], explanation="Both function_call and tool_calls is present in the message")
raise InvalidToolCallError(messages=[message], explanation="Both function_call and tool_calls is present in the message")
elif hasattr(message, "function_call") and message.function_call is not None:
function_call = message.function_call
elif hasattr(message, "tool_calls") and message.tool_calls is not None:
@@ -455,10 +455,10 @@ def assert_contains_valid_function_call(
function_call = message.tool_calls[0].function
else:
# Throw a missing function call error
raise MissingFunctionCallError(messages=[message])
raise MissingToolCallError(messages=[message])
if function_call_validator and not function_call_validator(function_call):
raise InvalidFunctionCallError(messages=[message], explanation=validation_failure_summary)
raise InvalidToolCallError(messages=[message], explanation=validation_failure_summary)
def assert_inner_monologue_is_valid(message: Message) -> None:

View File

@@ -3,7 +3,7 @@ import uuid
import pytest
from letta import create_client
from letta.schemas.letta_message import FunctionCallMessage
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
@@ -115,9 +115,9 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none):
tool_names = [t.name for t in [t1, t2, t3, t4]]
tool_names += ["send_message"]
for m in response.messages:
if isinstance(m, FunctionCallMessage):
if isinstance(m, ToolCallMessage):
# Check that it's equal to the first one
assert m.function_call.name == tool_names[0]
assert m.tool_call.name == tool_names[0]
# Pop out first one
tool_names = tool_names[1:]
@@ -220,9 +220,9 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
tool_names = [t.name for t in [t1, t2]]
tool_names += ["send_message"]
for m in messages:
if isinstance(m, FunctionCallMessage):
if isinstance(m, ToolCallMessage):
# Check that it's equal to the first one
assert m.function_call.name == tool_names[0]
assert m.tool_call.name == tool_names[0]
# Pop out first one
tool_names = tool_names[1:]
@@ -273,9 +273,9 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none):
# Check ordering of tool calls
tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]]
for m in response.messages:
if isinstance(m, FunctionCallMessage):
if isinstance(m, ToolCallMessage):
# Check that it's equal to the first one
assert m.function_call.name == tool_names[0]
assert m.tool_call.name == tool_names[0]
# Pop out first one
tool_names = tool_names[1:]

View File

@@ -15,7 +15,7 @@ from letta.schemas.agent import AgentState
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.job import JobStatus
from letta.schemas.letta_message import FunctionReturn
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType
from letta.utils import create_random_username
@@ -365,12 +365,12 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
response_message = None
for message in response.messages:
if isinstance(message, FunctionReturn):
if isinstance(message, ToolReturnMessage):
response_message = message
break
assert response_message, "FunctionReturn message not found in response"
res = response_message.function_return
assert response_message, "ToolReturnMessage message not found in response"
res = response_message.tool_return
assert "function output was truncated " in res
# TODO: Re-enable later

View File

@@ -18,8 +18,8 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCallMessage,
FunctionReturn,
ToolCallMessage,
ToolReturnMessage,
InternalMonologue,
LettaMessage,
SystemMessage,
@@ -172,8 +172,8 @@ def test_agent_interactions(mock_e2b_api_key_none, client: Union[LocalClient, RE
SystemMessage,
UserMessage,
InternalMonologue,
FunctionCallMessage,
FunctionReturn,
ToolCallMessage,
ToolReturnMessage,
AssistantMessage,
], f"Unexpected message type: {type(letta_message)}"
@@ -258,7 +258,7 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "":
inner_thoughts_exist = True
inner_thoughts_count += 1
if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message":
if isinstance(chunk, ToolCallMessage) and chunk.tool_call and chunk.tool_call.name == "send_message":
send_message_ran = True
if isinstance(chunk, MessageStreamStatus):
if chunk == MessageStreamStatus.done:
@@ -534,7 +534,7 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
print("Messages=", message_response)
assert isinstance(message_response, LettaResponse)
assert isinstance(message_response.messages[-1], FunctionReturn)
assert isinstance(message_response.messages[-1], ToolReturnMessage)
message = message_response.messages[-1]
new_text = "This exact string would never show up in the message???"

View File

@@ -10,8 +10,8 @@ from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import (
FunctionCallMessage,
FunctionReturn,
ToolCallMessage,
ToolReturnMessage,
InternalMonologue,
LettaMessage,
SystemMessage,
@@ -677,14 +677,14 @@ def _test_get_messages_letta_format(
print(f"Assistant Message at {i}: {type(letta_message)}")
if reverse:
# Reverse handling: FunctionCallMessages come first
            # Reverse handling: ToolCallMessages come first
if message.tool_calls:
for tool_call in message.tool_calls:
try:
json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
assert isinstance(letta_message, FunctionCallMessage)
assert isinstance(letta_message, ToolCallMessage)
letta_message_index += 1
if letta_message_index >= len(letta_messages):
break
@@ -710,9 +710,9 @@ def _test_get_messages_letta_format(
json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
assert isinstance(letta_message, FunctionCallMessage)
assert tool_call.function.name == letta_message.function_call.name
assert tool_call.function.arguments == letta_message.function_call.arguments
assert isinstance(letta_message, ToolCallMessage)
assert tool_call.function.name == letta_message.tool_call.name
assert tool_call.function.arguments == letta_message.tool_call.arguments
letta_message_index += 1
if letta_message_index >= len(letta_messages):
break
@@ -729,8 +729,8 @@ def _test_get_messages_letta_format(
letta_message_index += 1
elif message.role == MessageRole.tool:
assert isinstance(letta_message, FunctionReturn)
assert message.text == letta_message.function_return
assert isinstance(letta_message, ToolReturnMessage)
assert message.text == letta_message.tool_return
letta_message_index += 1
else:
@@ -802,7 +802,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
)
print(result)
assert result.status == "success"
assert result.function_return == "Ingested message Hello, world!", result.function_return
assert result.tool_return == "Ingested message Hello, world!", result.tool_return
assert not result.stdout
assert not result.stderr
@@ -815,7 +815,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
)
print(result)
assert result.status == "success"
assert result.function_return == "Ingested message Well well well", result.function_return
assert result.tool_return == "Ingested message Well well well", result.tool_return
assert not result.stdout
assert not result.stderr
@@ -828,8 +828,8 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
)
print(result)
assert result.status == "error"
assert "Error" in result.function_return, result.function_return
assert "missing 1 required positional argument" in result.function_return, result.function_return
assert "Error" in result.tool_return, result.tool_return
assert "missing 1 required positional argument" in result.tool_return, result.tool_return
assert not result.stdout
assert result.stderr
assert "missing 1 required positional argument" in result.stderr[0]
@@ -844,7 +844,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
)
print(result)
assert result.status == "success"
assert result.function_return == "Ingested message Well well well", result.function_return
assert result.tool_return == "Ingested message Well well well", result.tool_return
assert result.stdout
assert "I'm a distractor" in result.stdout[0]
assert not result.stderr
@@ -859,7 +859,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
)
print(result)
assert result.status == "success"
assert result.function_return == "Ingested message Well well well", result.function_return
assert result.tool_return == "Ingested message Well well well", result.tool_return
assert result.stdout
assert "I'm a distractor" in result.stdout[0]
assert not result.stderr
@@ -874,7 +874,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
)
print(result)
assert result.status == "success"
assert result.function_return == str(None), result.function_return
assert result.tool_return == str(None), result.tool_return
assert result.stdout
assert "I'm a distractor" in result.stdout[0]
assert not result.stderr