feat: Add tool calling to fast chat completions (#1109)

This commit is contained in:
Matthew Zhou
2025-02-25 15:13:35 -08:00
committed by GitHub
parent bb2bf65668
commit 71805b2a22
10 changed files with 507 additions and 43 deletions

View File

@@ -52,6 +52,8 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
# Multi agent tools
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
# Set of all built-in Letta tools
LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS)
# The name of the tool used to send message to the user
# May not be relevant in cases where the agent has multiple ways to message the user (send_imessage, send_discord_message, ...)
@@ -59,6 +61,11 @@ MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_t
DEFAULT_MESSAGE_TOOL = "send_message"
DEFAULT_MESSAGE_TOOL_KWARG = "message"
PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"
REQUEST_HEARTBEAT_PARAM = "request_heartbeat"
# Structured output models
STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"}

View File

@@ -6,10 +6,11 @@ from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.settings import tool_settings
def get_composio_api_key(actor: User, logger: Logger) -> Optional[str]:
def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Optional[str]:
api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
if not api_keys:
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
if logger:
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
if tool_settings.composio_api_key:
return tool_settings.composio_api_key
else:

View File

@@ -0,0 +1,171 @@
from collections import OrderedDict
from typing import Any, Dict, Optional
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.utils import get_friendly_error_msg
def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
    """Enables strict mode for a tool schema by setting 'strict' to True and
    disallowing additional properties in the parameters.

    Args:
        tool_schema (Dict[str, Any]): The original tool schema.

    Returns:
        Dict[str, Any]: A new tool schema with strict mode enabled. The input
        schema is not modified.
    """
    schema = tool_schema.copy()

    # Enable strict mode
    schema["strict"] = True

    # Ensure parameters is a valid dictionary
    parameters = schema.get("parameters", {})

    if isinstance(parameters, dict) and parameters.get("type") == "object":
        # Rebuild the nested dict instead of mutating it in place: the
        # top-level .copy() above is shallow, so writing into `parameters`
        # directly would also modify the caller's original schema.
        schema["parameters"] = {**parameters, "additionalProperties": False}

    return schema
def add_pre_execution_message(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
    """Adds a `pre_execution_message` parameter to a tool schema to prompt a natural,
    human-like message before executing the tool.

    Args:
        tool_schema (Dict[str, Any]): The original tool schema.

    Returns:
        Dict[str, Any]: A new tool schema with the `pre_execution_message` field added at the beginning.
    """
    schema = tool_schema.copy()
    parameters = schema.get("parameters", {})
    if not isinstance(parameters, dict) or parameters.get("type") != "object":
        return schema  # Do not modify if schema is not valid

    properties = parameters.get("properties", {})
    required = parameters.get("required", [])

    # Define the new `pre_execution_message` field with a refined description.
    # NOTE: each fragment ends with a space so the concatenated prompt reads
    # correctly (the original was missing one: "...next action.You MUST...").
    pre_execution_message_field = {
        "type": "string",
        "description": (
            "A concise message to be uttered before executing this tool. "
            "This should sound natural, as if a person is casually announcing their next action. "
            "You MUST also include punctuation at the end of this message."
        ),
    }

    # Ensure the pre-execution message is the first field in properties
    updated_properties = OrderedDict()
    updated_properties[PRE_EXECUTION_MESSAGE_ARG] = pre_execution_message_field
    updated_properties.update(properties)  # Retain all existing properties

    # Ensure pre-execution message is the first required field
    if PRE_EXECUTION_MESSAGE_ARG not in required:
        required = [PRE_EXECUTION_MESSAGE_ARG] + required

    # Update the schema with ordered properties and required list
    schema["parameters"] = {
        **parameters,
        "properties": dict(updated_properties),  # Convert OrderedDict back to dict
        "required": required,
    }

    return schema
def remove_request_heartbeat(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
"""Removes the `request_heartbeat` parameter from a tool schema if it exists.
Args:
tool_schema (Dict[str, Any]): The original tool schema.
Returns:
Dict[str, Any]: A new tool schema without `request_heartbeat`.
"""
schema = tool_schema.copy()
parameters = schema.get("parameters", {})
if isinstance(parameters, dict):
properties = parameters.get("properties", {})
required = parameters.get("required", [])
# Remove the `request_heartbeat` property if it exists
if "request_heartbeat" in properties:
properties.pop("request_heartbeat")
# Remove `request_heartbeat` from required fields if present
if "request_heartbeat" in required:
required = [r for r in required if r != "request_heartbeat"]
# Update parameters with modified properties and required list
schema["parameters"] = {**parameters, "properties": properties, "required": required}
return schema
# TODO: Deprecate the `execute_external_tool` function on the agent body
def execute_external_tool(
    agent_state: AgentState,
    function_name: str,
    function_args: dict,
    target_letta_tool: Tool,
    actor: User,
    allow_agent_state_modifications: bool = False,
) -> tuple[Any, Optional[SandboxRunResult]]:
    """Execute a Composio or custom (sandboxed) tool and return its result.

    Args:
        agent_state (AgentState): State of the calling agent; read for env vars and,
            optionally, deep-copied for the sandbox.
        function_name (str): Name of the function the model requested.
        function_args (dict): Parsed arguments for the call.
        target_letta_tool (Tool): Resolved Letta tool record to execute.
        actor (User): User on whose behalf the tool runs.
        allow_agent_state_modifications (bool): When True, a stripped deep copy of
            `agent_state` is passed into the sandbox so the tool may modify it.

    Returns:
        tuple[Any, Optional[SandboxRunResult]]: The tool's return value (or a friendly
        error message string if execution failed) and, for sandboxed custom tools,
        the sandbox run result; otherwise None.
    """
    # TODO: need to have an AgentState object that actually has full access to the block data
    # this is because the sandbox tools need to be able to access block.value to edit this data
    try:
        if target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
            action_name = generate_composio_action_from_func_name(target_letta_tool.name)
            # Get entity ID from the agent_state
            entity_id = None
            for env_var in agent_state.tool_exec_environment_variables:
                if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
                    entity_id = env_var.value
            # Get composio_api_key
            composio_api_key = get_composio_api_key(actor=actor)
            function_response = execute_composio_action(
                action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
            )
            return function_response, None
        elif target_letta_tool.tool_type == ToolType.CUSTOM:
            # Parse the source code to extract function annotations
            annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
            # Coerce the function arguments to the correct types based on the annotations
            function_args = coerce_dict_args_by_annotations(function_args, annotations)

            # execute tool in a sandbox
            # TODO: allow agent_state to specify which sandbox to execute tools in
            # TODO: This is only temporary, can remove after we publish a pip package with this object
            if allow_agent_state_modifications:
                agent_state_copy = agent_state.__deepcopy__()
                agent_state_copy.tools = []
                agent_state_copy.tool_rules = []
            else:
                agent_state_copy = None

            sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
            function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
            # TODO: Bring this back
            # if allow_agent_state_modifications and updated_agent_state is not None:
            #     self.update_memory_if_changed(updated_agent_state.memory)
            return function_response, sandbox_run_result
        else:
            # Previously this fell through and implicitly returned None (violating the
            # annotated tuple return). Raise instead; the except below converts this
            # into a friendly error tuple like any other execution failure.
            raise ValueError(f"Tool type {target_letta_tool.tool_type} is not supported by execute_external_tool")
    except Exception as e:
        # Need to catch the error here, or else truncation won't happen
        # TODO: modify to function execution error
        function_response = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
        return function_response, None

View File

@@ -99,7 +99,7 @@ class ChatCompletionRequest(BaseModel):
"""https://platform.openai.com/docs/api-reference/chat/create"""
model: str
messages: List[ChatMessage]
messages: List[Union[ChatMessage, Dict]]
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, int]] = None
logprobs: Optional[bool] = False

View File

@@ -1,19 +1,39 @@
import asyncio
import json
import uuid
from typing import TYPE_CHECKING, List, Optional, Union
import httpx
import openai
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
from openai.types.chat.completion_create_params import CompletionCreateParams
from starlette.concurrency import run_in_threadpool
from letta.agent import Agent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_TOOL_SET, NON_USER_MSG_PREFIX, PRE_EXECUTION_MESSAGE_ARG
from letta.helpers.tool_execution_helper import (
add_pre_execution_message,
enable_strict_mode,
execute_external_tool,
remove_request_heartbeat,
)
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import (
AssistantMessage,
ChatCompletionRequest,
Tool,
ToolCall,
ToolCallFunction,
ToolMessage,
UserMessage,
)
from letta.schemas.user import User
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
# TODO this belongs in a controller!
from letta.server.rest_api.utils import (
@@ -52,20 +72,53 @@ async def create_fast_chat_completions(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
# TODO: This is necessary, we need to factor out CompletionCreateParams due to weird behavior
actor = server.user_manager.get_user_or_default(user_id=user_id)
agent_id = str(completion_request.get("user", None))
if agent_id is None:
error_msg = "Must pass agent_id in the 'user' field"
logger.error(error_msg)
raise HTTPException(status_code=400, detail=error_msg)
model = completion_request.get("model")
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
actor = server.user_manager.get_user_or_default(user_id=user_id)
agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state.llm_config.model_endpoint_type != "openai":
raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.")
# Convert Letta messages to OpenAI messages
in_context_messages = server.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)
openai_messages = convert_letta_messages_to_openai(in_context_messages)
# Also parse user input from completion_request and append
input_message = get_messages_from_completion_request(completion_request)[-1]
openai_messages.append(input_message)
# Tools we allow this agent to call
tools = [t for t in agent_state.tools if t.name not in LETTA_TOOL_SET and t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}]
# Initial request
openai_request = ChatCompletionRequest(
model=agent_state.llm_config.model,
messages=openai_messages,
# TODO: This nested thing here is so ugly, need to refactor
tools=(
[
Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema))))
for t in tools
]
if tools
else None
),
tool_choice="auto",
user=user_id,
max_completion_tokens=agent_state.llm_config.max_tokens,
temperature=agent_state.llm_config.temperature,
stream=True,
)
# Create the OpenAI async client
client = openai.AsyncClient(
api_key=model_settings.openai_api_key,
max_retries=0,
http_client=httpx.AsyncClient(
timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
timeout=httpx.Timeout(connect=15.0, read=30.0, write=15.0, pool=15.0),
follow_redirects=True,
limits=httpx.Limits(
max_connections=50,
@@ -75,38 +128,175 @@ async def create_fast_chat_completions(
),
)
# Magic message manipulating
input_message = get_messages_from_completion_request(completion_request)[-1]
completion_request.pop("messages")
# Get in context messages
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)
openai_dict_in_context_messages = convert_letta_messages_to_openai(in_context_messages)
openai_dict_in_context_messages.append(input_message)
# The messages we want to persist to the Letta agent
user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
message_db_queue = [user_message]
async def event_stream():
# TODO: Factor this out into separate interface
response_accumulator = []
"""
A function-calling loop:
- We stream partial tokens.
- If we detect a tool call (finish_reason="tool_calls"), we parse it,
add two messages to the conversation:
(a) assistant message with tool_calls referencing the same ID
(b) a tool message referencing that ID, containing the tool result.
- Re-invoke the OpenAI request with updated conversation, streaming again.
- End when finish_reason="stop" or no more tool calls.
"""
stream = await client.chat.completions.create(**completion_request, messages=openai_dict_in_context_messages)
# We'll keep updating this conversation in a loop
conversation = openai_messages[:]
async with stream:
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
# TODO: This does not support tool calling right now
response_accumulator.append(chunk.choices[0].delta.content)
yield f"data: {chunk.model_dump_json()}\n\n"
while True:
# Make the streaming request to OpenAI
stream = await client.chat.completions.create(**openai_request.model_dump(exclude_unset=True))
# Construct messages
user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
assistant_message = create_assistant_message_from_openai_response(
response_text="".join(response_accumulator), agent_id=agent_id, model=str(model), actor=actor
)
content_buffer = []
tool_call_name = None
tool_call_args_str = ""
tool_call_id = None
tool_call_happened = False
finish_reason_stop = False
optimistic_json_parser = OptimisticJSONParser(strict=True)
current_parsed_json_result = {}
async with stream:
async for chunk in stream:
choice = chunk.choices[0]
delta = choice.delta
finish_reason = choice.finish_reason # "tool_calls", "stop", or None
if delta.content:
content_buffer.append(delta.content)
yield f"data: {chunk.model_dump_json()}\n\n"
# CASE B: Partial tool call info
if delta.tool_calls:
# Typically there's only one in delta.tool_calls
tc = delta.tool_calls[0]
if tc.function.name:
tool_call_name = tc.function.name
if tc.function.arguments:
tool_call_args_str += tc.function.arguments
# See if we can stream out the pre-execution message
parsed_args = optimistic_json_parser.parse(tool_call_args_str)
if parsed_args.get(
PRE_EXECUTION_MESSAGE_ARG
) and current_parsed_json_result.get( # Ensure key exists and is not None/empty
PRE_EXECUTION_MESSAGE_ARG
) != parsed_args.get(
PRE_EXECUTION_MESSAGE_ARG
):
# Only stream if there's something new to stream
# We do this way to avoid hanging JSON at the end of the stream, e.g. '}'
if parsed_args != current_parsed_json_result:
current_parsed_json_result = parsed_args
synthetic_chunk = ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created,
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(content=tc.function.arguments, role="assistant"),
finish_reason=None,
)
],
)
yield f"data: {synthetic_chunk.model_dump_json()}\n\n"
# We might generate a unique ID for the tool call
if tc.id:
tool_call_id = tc.id
# Check finish_reason
if finish_reason == "tool_calls":
tool_call_happened = True
break
elif finish_reason == "stop":
finish_reason_stop = True
break
if content_buffer:
# We treat that partial text as an assistant message
content = "".join(content_buffer)
conversation.append({"role": "assistant", "content": content})
# Create an assistant message here to persist later
assistant_message = create_assistant_message_from_openai_response(
response_text=content, agent_id=agent_id, model=agent_state.llm_config.model, actor=actor
)
message_db_queue.append(assistant_message)
if tool_call_happened:
# Parse the tool call arguments
try:
tool_args = json.loads(tool_call_args_str)
except json.JSONDecodeError:
tool_args = {}
if not tool_call_id:
# If no tool_call_id given by the model, generate one
tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
# 1) Insert the "assistant" message with the tool_calls field
# referencing the same tool_call_id
assistant_tool_call_msg = AssistantMessage(
content=None,
tool_calls=[ToolCall(id=tool_call_id, function=ToolCallFunction(name=tool_call_name, arguments=tool_call_args_str))],
)
conversation.append(assistant_tool_call_msg.model_dump())
# 2) Execute the tool
target_tool = next((x for x in tools if x.name == tool_call_name), None)
if not target_tool:
# Tool not found, handle error
yield f"data: {json.dumps({'error': 'Tool not found', 'tool': tool_call_name})}\n\n"
break
try:
tool_result, _ = execute_external_tool(
agent_state=agent_state,
function_name=tool_call_name,
function_args=tool_args,
target_letta_tool=target_tool,
actor=actor,
allow_agent_state_modifications=False,
)
except Exception as e:
tool_result = f"Failed to call tool. Error: {e}"
# 3) Insert the "tool" message referencing the same tool_call_id
tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id)
conversation.append(tool_message.model_dump())
# 4) Add a user message prompting the tool call result summarization
heartbeat_user_message = UserMessage(
content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user.",
)
conversation.append(heartbeat_user_message.model_dump())
# Now, re-invoke OpenAI with the updated conversation
openai_request.messages = conversation
continue # Start the while loop again
if finish_reason_stop:
# Model is done, no more calls
break
# If we reach here, no tool call, no "stop", but we've ended streaming
# Possibly a model error or some other finish reason. We'll just end.
break
# Persist both in one synchronous DB call, done in a threadpool
await run_in_threadpool(
server.agent_manager.append_to_in_context_messages,
[user_message, assistant_message],
message_db_queue,
agent_id=agent_id,
actor=actor,
)

View File

@@ -7,7 +7,6 @@ from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
import pytz
from fastapi import Header, HTTPException
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
@@ -145,7 +144,7 @@ def create_user_message(input_message: dict, agent_id: str, actor: User) -> Mess
Converts a user input message into the internal structured format.
"""
# Generate timestamp in the correct format
now = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
now = datetime.now(timezone.utc).isoformat()
# Format message as structured JSON
structured_message = {"type": "user_message", "message": input_message["content"], "time": now}
@@ -197,7 +196,7 @@ def create_assistant_message_from_openai_response(
agent_id=agent_id,
model=model,
tool_calls=[tool_call],
tool_call_id=None,
tool_call_id=tool_call_id,
created_at=datetime.now(timezone.utc),
)

View File

@@ -21,8 +21,10 @@ from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate
from letta.schemas.passage import Passage as PydanticPassage
@@ -613,6 +615,40 @@ class AgentManager:
)
return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor)
    # TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version
    @enforce_types
    def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
        """
        Update internal memory object and system prompt if there have been modifications.

        Args:
            agent_id (str): ID of the agent whose memory is being compared and updated.
            new_memory (Memory): the new memory object to compare to the current memory object
            actor (PydanticUser): user performing the update, used to scope DB access.

        Returns:
            PydanticAgentState: the agent state; its blocks and system prompt are refreshed
            only when the compiled memory actually differs from `new_memory`.
        """
        agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
        # Compare the rendered (compiled) memory strings to detect any change at all
        if agent_state.memory.compile() != new_memory.compile():
            # update the blocks (LRW) in the DB
            for label in agent_state.memory.list_block_labels():
                updated_value = new_memory.get_block(label).value
                if updated_value != agent_state.memory.get_block(label).value:
                    # update the block if it's changed
                    block_id = agent_state.memory.get_block(label).id
                    block = self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor)

            # refresh memory from DB (using block ids)
            agent_state.memory = Memory(
                blocks=[self.block_manager.get_block_by_id(block.id, actor=actor) for block in agent_state.memory.get_blocks()]
            )

            # NOTE: don't do this since re-building the memory is handled at the start of the step
            # rebuild memory - this records the last edited timestamp of the memory
            # TODO: pass in update timestamp from block edit time
            agent_state = self.rebuild_system_prompt(agent_id=agent_id, actor=actor)

        return agent_state
# ======================================================================================================================
# Source Management
# ======================================================================================================================

View File

@@ -107,12 +107,14 @@ class BlockManager:
@enforce_types
def add_default_blocks(self, actor: PydanticUser):
for persona_file in list_persona_files():
text = open(persona_file, "r", encoding="utf-8").read()
with open(persona_file, "r", encoding="utf-8") as f:
text = f.read()
name = os.path.basename(persona_file).replace(".txt", "")
self.create_or_update_block(Persona(template_name=name, value=text, is_template=True), actor=actor)
for human_file in list_human_files():
text = open(human_file, "r", encoding="utf-8").read()
with open(human_file, "r", encoding="utf-8") as f:
text = f.read()
name = os.path.basename(human_file).replace(".txt", "")
self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor)

View File

@@ -2,6 +2,7 @@ from typing import List, Optional
from sqlalchemy import and_, or_
from letta.log import get_logger
from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
@@ -11,6 +12,8 @@ from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
logger = get_logger(__name__)
class MessageManager:
"""Manager class to handle business logic related to Messages."""
@@ -37,7 +40,7 @@ class MessageManager:
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids))
if len(results) != len(message_ids):
raise NoResultFound(
logger.warning(
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
)

View File

@@ -14,7 +14,9 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
from letta.schemas.tool import ToolCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
# --- Server Management --- #
@@ -69,9 +71,49 @@ def roll_dice_tool(client):
@pytest.fixture(scope="function")
def agent(client, roll_dice_tool):
def weather_tool(client):
    """Pytest fixture: registers a `get_weather` tool via the client and yields it."""

    def get_weather(location: str) -> str:
        """
        Fetches the current weather for a given location.

        Parameters:
            location (str): The location to get the weather for.

        Returns:
            str: A formatted string describing the weather in the given location.

        Raises:
            RuntimeError: If the request to fetch weather data fails.
        """
        import requests

        # NOTE(review): hits the live wttr.in service with no timeout — tests using
        # this tool depend on network availability.
        url = f"https://wttr.in/{location}?format=%C+%t"
        response = requests.get(url)
        if response.status_code == 200:
            weather_data = response.text
            return f"The weather in {location} is {weather_data}."
        else:
            raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")

    tool = client.create_or_update_tool(func=get_weather)

    # Yield the created tool
    yield tool
@pytest.fixture(scope="function")
def composio_gmail_get_profile_tool(default_user):
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
yield tool
@pytest.fixture(scope="function")
def agent(client, roll_dice_tool, weather_tool, composio_gmail_get_profile_tool):
"""Creates an agent and ensures cleanup after tests."""
agent_state = client.create_agent(name=f"test_client_{uuid.uuid4()}", tool_ids=[roll_dice_tool.id])
agent_state = client.create_agent(
name=f"test_compl_{str(uuid.uuid4())[5:]}", tool_ids=[roll_dice_tool.id, weather_tool.id, composio_gmail_get_profile_tool.id]
)
yield agent_state
client.delete_agent(agent_state.id)
@@ -111,6 +153,19 @@ def _assert_valid_chunk(chunk, idx, chunks):
# --- Test Cases --- #
@pytest.mark.parametrize("message", ["What's the weather in SF?"])
@pytest.mark.parametrize("endpoint", ["fast/chat/completions"])
def test_tool_usage_fast_chat_completions(mock_e2b_api_key_none, client, agent, message, endpoint):
    """Tests chat completion streaming via SSE."""
    request = _get_chat_request(agent.id, message)

    response = _sse_post(f"{client.base_url}/openai/{client.api_prefix}/{endpoint}", request.model_dump(exclude_none=True), client.headers)

    # NOTE(review): smoke test only — it streams and prints chunk content but makes
    # no assertions that the weather tool was actually invoked. TODO: assert on the
    # streamed tool-call / pre-execution message content.
    for chunk in response:
        if isinstance(chunk, ChatCompletionChunk) and chunk.choices:
            print(chunk.choices[0].delta.content)
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
@pytest.mark.parametrize("endpoint", ["chat/completions", "fast/chat/completions"])
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message, endpoint):